diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 637593b9de..b92d4c35a8 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -7,7 +7,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml new file mode 100644 index 0000000000..c57da7cb5f --- /dev/null +++ b/.github/actions/setup-web/action.yml @@ -0,0 +1,33 @@ +name: Setup Web Environment +description: Setup pnpm, Node.js, and install web dependencies. 
+ +inputs: + node-version: + description: Node.js version to use + required: false + default: "22" + install-dependencies: + description: Whether to install web dependencies after setting up Node.js + required: false + default: "true" + +runs: + using: composite + steps: + - name: Install pnpm + uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + with: + node-version: ${{ inputs.node-version }} + cache: pnpm + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Install dependencies + if: ${{ inputs.install-dependencies == 'true' }} + shell: bash + run: pnpm --dir web install --frozen-lockfile diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 78f6eefd0d..5715b1e83f 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -19,19 +19,3 @@ updates: uv-dependencies: patterns: - "*" - - package-ecosystem: "npm" - directory: "/web" - schedule: - interval: "weekly" - open-pull-requests-limit: 2 - groups: - storybook: - patterns: - - "storybook" - - "@storybook/*" - npm-dependencies: - patterns: - - "*" - exclude-patterns: - - "storybook" - - "@storybook/*" diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml new file mode 100644 index 0000000000..c0d1818691 --- /dev/null +++ b/.github/workflows/anti-slop.yml @@ -0,0 +1,19 @@ +name: Anti-Slop PR Check + +on: + pull_request_target: + types: [opened, edited, synchronize] + +permissions: + pull-requests: write + contents: read + +jobs: + anti-slop: + runs-on: ubuntu-latest + steps: + - uses: peakoss/anti-slop@v0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + close-pr: false + failure-add-pr-labels: "needs-revision" diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 52e3272f99..03f6917dca 100644 --- a/.github/workflows/api-tests.yml +++ 
b/.github/workflows/api-tests.yml @@ -22,12 +22,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -51,7 +51,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@v2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 4571fd1cd1..4b48e741df 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -12,22 +12,34 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check Docker Compose inputs id: docker-compose-changes - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | docker/generate_docker_compose docker/.env.example docker/docker-compose-template.yaml docker/docker-compose.yaml - - uses: actions/setup-python@v6 + - name: Check web inputs + id: web-changes + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + with: + files: | + web/** + - name: Check api inputs + id: api-changes + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + with: + files: | + api/** + - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" - - uses: astral-sh/setup-uv@v7 + - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 - name: 
Generate Docker Compose if: steps.docker-compose-changes.outputs.any_changed == 'true' @@ -35,7 +47,8 @@ jobs: cd docker ./generate_docker_compose - - run: | + - if: steps.api-changes.outputs.any_changed == 'true' + run: | cd api uv sync --dev # fmt first to avoid line too long @@ -46,11 +59,13 @@ jobs: uv run ruff format .. - name: count migration progress + if: steps.api-changes.outputs.any_changed == 'true' run: | cd api ./cnt_base.sh - name: ast-grep + if: steps.api-changes.outputs.any_changed == 'true' run: | # ast-grep exits 1 if no matches are found; allow idempotent runs. uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true @@ -84,4 +99,16 @@ jobs: run: | uvx --python 3.13 mdformat . --exclude ".agents/skills/**" - - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 + - name: Setup web environment + if: steps.web-changes.outputs.any_changed == 'true' + uses: ./.github/actions/setup-web + with: + node-version: "24" + + - name: ESLint autofix + if: steps.web-changes.outputs.any_changed == 'true' + run: | + cd web + pnpm eslint --concurrency=2 --prune-suppressions --quiet || true + + - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index ac7f3a6b48..94466d151c 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -53,26 +53,26 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Login to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 - name: Set up Docker Buildx - uses: 
docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Extract metadata for Docker id: meta - uses: docker/metadata-action@v5 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0 with: images: ${{ env[matrix.image_name_env] }} - name: Build Docker image id: build - uses: docker/build-push-action@v6 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: context: "{{defaultContext}}:${{ matrix.context }}" platforms: ${{ matrix.platform }} @@ -91,7 +91,7 @@ jobs: touch "/tmp/digests/${sanitized_digest}" - name: Upload digest - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* @@ -113,21 +113,21 @@ jobs: context: "web" steps: - name: Download digests - uses: actions/download-artifact@v7 + uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0 with: path: /tmp/digests pattern: digests-${{ matrix.context }}-* merge-multiple: true - name: Login to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - name: Extract metadata for Docker id: meta - uses: docker/metadata-action@v5 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0 with: images: ${{ env[matrix.image_name_env] }} tags: | diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index e20cf9850b..84a506a325 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -13,13 +13,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 
persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: enable-cache: true python-version: "3.12" @@ -40,7 +40,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -63,13 +63,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: enable-cache: true python-version: "3.12" @@ -94,7 +94,7 @@ jobs: sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-agent-dev.yml index dd759f7ba5..cd5fe9242e 100644 --- a/.github/workflows/deploy-agent-dev.yml +++ b/.github/workflows/deploy-agent-dev.yml @@ -19,7 +19,7 @@ jobs: github.event.workflow_run.head_branch == 'deploy/agent-dev' steps: - name: Deploy to server - uses: appleboy/ssh-action@v1 + uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: host: ${{ secrets.AGENT_DEV_SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 38fa0b9a7f..954537663a 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -16,7 +16,7 @@ jobs: github.event.workflow_run.head_branch == 'deploy/dev' 
steps: - name: Deploy to server - uses: appleboy/ssh-action@v1 + uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: host: ${{ secrets.SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml index a3fd52afc6..c6f1cc7e6f 100644 --- a/.github/workflows/deploy-hitl.yml +++ b/.github/workflows/deploy-hitl.yml @@ -16,7 +16,7 @@ jobs: github.event.workflow_run.head_branch == 'build/feat/hitl' steps: - name: Deploy to server - uses: appleboy/ssh-action@v1 + uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: host: ${{ secrets.HITL_SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index cadc1b5507..340b380dc9 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -32,13 +32,13 @@ jobs: context: "web" steps: - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Build Docker Image - uses: docker/build-push-action@v6 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: push: false context: "{{defaultContext}}:${{ matrix.context }}" diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 06782b53c1..278e10bc04 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -9,6 +9,6 @@ jobs: pull-requests: write runs-on: ubuntu-latest steps: - - uses: actions/labeler@v6 + - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 with: sync-labels: true diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index d6653de950..ef2e3c7bb4 100644 --- a/.github/workflows/main-ci.yml 
+++ b/.github/workflows/main-ci.yml @@ -27,8 +27,8 @@ jobs: vdb-changed: ${{ steps.changes.outputs.vdb }} migration-changed: ${{ steps.changes.outputs.migration }} steps: - - uses: actions/checkout@v6 - - uses: dorny/paths-filter@v3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 id: changes with: filters: | @@ -39,6 +39,7 @@ jobs: web: - 'web/**' - '.github/workflows/web-tests.yml' + - '.github/actions/setup-web/**' vdb: - 'api/core/rag/datasource/**' - 'docker/**' diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index f9fbcba465..0278e1e0d3 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -21,7 +21,7 @@ jobs: if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} steps: - name: Download pyrefly diff artifact - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,7 +49,7 @@ jobs: run: unzip -o pyrefly_diff.zip - name: Post comment - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 14338e85b3..cceaf58789 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -17,12 +17,12 @@ jobs: pull-requests: write steps: - name: Checkout PR branch - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: 
enable-cache: true @@ -55,7 +55,7 @@ jobs: echo ${{ github.event.pull_request.number }} > pr_number.txt - name: Upload pyrefly diff - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: pyrefly_diff path: | @@ -64,7 +64,7 @@ jobs: - name: Comment PR with pyrefly diff if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index b15c26a096..c21331ec0d 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -16,6 +16,6 @@ jobs: runs-on: ubuntu-latest steps: - name: Check title - uses: amannn/action-semantic-pull-request@v6.1.1 + uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index b6df1d7e93..5cf52daed2 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -18,7 +18,7 @@ jobs: pull-requests: write steps: - - uses: actions/stale@v10 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 with: days-before-issue-stale: 15 days-before-issue-close: 3 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index eb13c3d096..4168f890f5 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -19,13 +19,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: 
tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: enable-cache: false python-version: "3.12" @@ -67,36 +67,22 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** .github/workflows/style.yml + .github/actions/setup-web/** - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup NodeJS - uses: actions/setup-node@v6 + - name: Setup web environment if: steps.changed-files.outputs.any_changed == 'true' - with: - node-version: 22 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Web dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm install --frozen-lockfile + uses: ./.github/actions/setup-web - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' @@ -134,14 +120,14 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | **.sh @@ -152,7 +138,7 @@ jobs: .editorconfig - name: Super-linter - uses: super-linter/super-linter/slim@v8 + uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0 
if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index d9a1168636..3fc351c0c2 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -21,12 +21,12 @@ jobs: working-directory: sdks/nodejs-client steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Use Node.js - uses: actions/setup-node@v6 + uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: node-version: 22 cache: '' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index b431c36a8b..ff07313ebe 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} @@ -48,18 +48,10 @@ jobs: git config --global user.name "github-actions[bot]" git config --global user.email "github-actions[bot]@users.noreply.github.com" - - name: Install pnpm - uses: pnpm/action-setup@v4 + - name: Setup web environment + uses: ./.github/actions/setup-web with: - package_json_file: web/package.json - run_install: false - - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: 22 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml + install-dependencies: "false" - name: Detect changed files and generate diff id: detect_changes @@ -130,7 +122,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70 with: anthropic_api_key: ${{ 
secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index 66a29453b4..1caaddd47a 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -21,7 +21,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 @@ -59,7 +59,7 @@ jobs: - name: Trigger i18n sync workflow if: steps.detect.outputs.has_changes == 'true' - uses: peter-evans/repository-dispatch@v3 + uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1 with: token: ${{ secrets.GITHUB_TOKEN }} event-type: i18n-sync diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 7735afdaca..8cb7db7601 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -19,19 +19,19 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Free Disk Space - uses: endersonmenezes/free-disk-space@v3 + uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2 with: remove_dotnet: true remove_haskell: true remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -60,7 +60,7 @@ jobs: # tiflash - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase) - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.yaml diff --git a/.github/workflows/web-tests.yml 
b/.github/workflows/web-tests.yml index 659620b2a9..33e9170b02 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -26,32 +26,19 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: 22 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Install dependencies - run: pnpm install --frozen-lockfile + - name: Setup web environment + uses: ./.github/actions/setup-web - name: Run tests run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage - name: Upload blob report if: ${{ !cancelled() }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: blob-report-${{ matrix.shardIndex }} path: web/.vitest-reports/* @@ -70,28 +57,15 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: 22 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Install dependencies - run: pnpm install --frozen-lockfile + - name: Setup web environment + uses: ./.github/actions/setup-web - name: Download blob reports - uses: actions/download-artifact@v6 + uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0 with: path: web/.vitest-reports pattern: blob-report-* @@ -419,7 +393,7 @@ jobs: - name: Upload Coverage Artifact if: 
steps.coverage-summary.outputs.has_coverage == 'true' - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: web-coverage-report path: web/coverage @@ -435,36 +409,22 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** .github/workflows/web-tests.yml + .github/actions/setup-web/** - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup NodeJS - uses: actions/setup-node@v6 + - name: Setup web environment if: steps.changed-files.outputs.any_changed == 'true' - with: - node-version: 22 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Web dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm install --frozen-lockfile + uses: ./.github/actions/setup-web - name: Web build check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index 700b815c3b..c3e2c50c52 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution", + 
"dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution", "--loglevel", "INFO" ], diff --git a/Makefile b/Makefile index 0aff26b3e5..55871c86a7 100644 --- a/Makefile +++ b/Makefile @@ -68,8 +68,9 @@ lint: @echo "✅ Linting complete" type-check: - @echo "📝 Running type checks (basedpyright + mypy)..." + @echo "📝 Running type checks (basedpyright + pyrefly + mypy)..." @./dev/basedpyright-check $(PATH_TO_CHECK) + @./dev/pyrefly-check-local @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . @echo "✅ Type checks complete" @@ -131,7 +132,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" - @echo " make type-check - Run type checks (basedpyright, mypy)" + @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/README.md b/README.md index 90961a5346..bef8f6b782 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ README in বাংলা

-Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. +Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features: ## Quick start @@ -133,7 +133,7 @@ Star Dify on GitHub and be instantly notified of new releases. ### Custom configurations -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. 
You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). #### Customizing Suggested Questions diff --git a/api/.importlinter b/api/.importlinter index e4536b1f10..5c0a6e1288 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -44,9 +44,7 @@ forbidden_modules = allow_indirect_imports = True ignore_imports = dify_graph.nodes.agent.agent_node -> extensions.ext_database - dify_graph.nodes.llm.file_saver -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis @@ -112,9 +110,7 @@ ignore_imports = dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.tool.tool_node -> models dify_graph.nodes.agent.agent_node -> models.model - dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy dify_graph.nodes.llm.node -> core.helper.code_executor dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output @@ -135,9 +131,7 @@ ignore_imports = dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager dify_graph.nodes.tool.tool_node -> core.tools.errors dify_graph.nodes.agent.agent_node -> extensions.ext_database - dify_graph.nodes.llm.file_saver -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.nodes.agent.agent_node -> models dify_graph.nodes.llm.node -> models.model dify_graph.nodes.agent.agent_node -> services diff --git a/api/AGENTS.md 
b/api/AGENTS.md index 13adb42276..d43d2528b8 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co - Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values). - Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason. +- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`. +- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional). +- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown. + +```python +from datetime import datetime +from typing import NotRequired, TypedDict + + +class UserProfile(TypedDict): + user_id: str + email: str + created_at: datetime + nickname: NotRequired[str] +``` + - For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance: ```python diff --git a/api/commands.py b/api/commands.py deleted file mode 100644 index 75b17df78e..0000000000 --- a/api/commands.py +++ /dev/null @@ -1,2670 +0,0 @@ -import base64 -import datetime -import json -import logging -import secrets -import time -from typing import Any - -import click -import sqlalchemy as sa -from flask import current_app -from pydantic import TypeAdapter -from sqlalchemy import select -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker - -from configs import dify_config -from constants.languages import languages -from core.helper import encrypter -from core.plugin.entities.plugin_daemon import CredentialType -from core.plugin.impl.plugin import PluginInstaller -from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.datasource.vdb.vector_type import VectorType -from 
core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.models.document import ChildDocument, Document -from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params -from events.app_event import app_was_created -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from extensions.ext_storage import storage -from extensions.storage.opendal_storage import OpenDALStorage -from extensions.storage.storage_type import StorageType -from libs.db_migration_lock import DbMigrationAutoRenewLock -from libs.helper import email as email_validate -from libs.password import hash_password, password_pattern, valid_password -from libs.rsa import generate_key_pair -from models import Tenant -from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment -from models.dataset import Document as DatasetDocument -from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile -from models.oauth import DatasourceOauthParamConfig, DatasourceProvider -from models.provider import Provider, ProviderModel -from models.provider_ids import DatasourceProviderID, ToolProviderID -from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding -from models.tools import ToolOAuthSystemClient -from services.account_service import AccountService, RegisterService, TenantService -from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs -from services.plugin.data_migration import PluginDataMigration -from services.plugin.plugin_migration import PluginMigration -from services.plugin.plugin_service import PluginService -from services.retention.conversation.messages_clean_policy import create_message_clean_policy -from services.retention.conversation.messages_clean_service import MessagesCleanService -from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import 
WorkflowRunCleanup -from tasks.remove_app_and_related_data_task import delete_draft_variables_batch - -logger = logging.getLogger(__name__) - -DB_UPGRADE_LOCK_TTL_SECONDS = 60 - - -@click.command("reset-password", help="Reset the account password.") -@click.option("--email", prompt=True, help="Account email to reset password for") -@click.option("--new-password", prompt=True, help="New password") -@click.option("--password-confirm", prompt=True, help="Confirm new password") -def reset_password(email, new_password, password_confirm): - """ - Reset password of owner account - Only available in SELF_HOSTED mode - """ - if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style("Passwords do not match.", fg="red")) - return - normalized_email = email.strip().lower() - - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) - - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return - - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. 
Must match {password_pattern}", fg="red")) - return - - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() - - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) - - -@click.command("reset-email", help="Reset the account email.") -@click.option("--email", prompt=True, help="Current account email") -@click.option("--new-email", prompt=True, help="New email") -@click.option("--email-confirm", prompt=True, help="Confirm new email") -def reset_email(email, new_email, email_confirm): - """ - Replace account email - :return: - """ - if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style("New emails do not match.", fg="red")) - return - normalized_new_email = new_email.strip().lower() - - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) - - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return - - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return - - account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) - - -@click.command( - "reset-encrypt-key-pair", - help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " - "After the reset, all LLM credentials will become invalid, " - "requiring re-entry." - "Only support SELF_HOSTED mode.", -) -@click.confirmation_option( - prompt=click.style( - "Are you sure you want to reset encrypt key pair? 
This operation cannot be rolled back!", fg="red" - ) -) -def reset_encrypt_key_pair(): - """ - Reset the encrypted key pair of workspace for encrypt LLM credentials. - After the reset, all LLM credentials will become invalid, requiring re-entry. - Only support SELF_HOSTED mode. - """ - if dify_config.EDITION != "SELF_HOSTED": - click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) - return - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - tenants = session.query(Tenant).all() - for tenant in tenants: - if not tenant: - click.echo(click.style("No workspaces found. Run /install first.", fg="red")) - return - - tenant.encrypt_public_key = generate_key_pair(tenant.id) - - session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() - session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() - - click.echo( - click.style( - f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", - fg="green", - ) - ) - - -@click.command("vdb-migrate", help="Migrate vector db.") -@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") -def vdb_migrate(scope: str): - if scope in {"knowledge", "all"}: - migrate_knowledge_vector_database() - if scope in {"annotation", "all"}: - migrate_annotation_vector_database() - - -def migrate_annotation_vector_database(): - """ - Migrate annotation datas to target vector database . 
- """ - click.echo(click.style("Starting annotation data migration.", fg="green")) - create_count = 0 - skipped_count = 0 - total_count = 0 - page = 1 - while True: - try: - # get apps info - per_page = 50 - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - apps = ( - session.query(App) - .where(App.status == "normal") - .order_by(App.created_at.desc()) - .limit(per_page) - .offset((page - 1) * per_page) - .all() - ) - if not apps: - break - except SQLAlchemyError: - raise - - page += 1 - for app in apps: - total_count = total_count + 1 - click.echo( - f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." - ) - try: - click.echo(f"Creating app annotation index: {app.id}") - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() - ) - - if not app_annotation_setting: - skipped_count = skipped_count + 1 - click.echo(f"App annotation setting disabled: {app.id}") - continue - # get dataset_collection_binding info - dataset_collection_binding = ( - session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) - .first() - ) - if not dataset_collection_binding: - click.echo(f"App annotation collection binding not found: {app.id}") - continue - annotations = session.scalars( - select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) - ).all() - dataset = Dataset( - id=app.id, - tenant_id=app.tenant_id, - indexing_technique="high_quality", - embedding_model_provider=dataset_collection_binding.provider_name, - embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id, - ) - documents = [] - if annotations: - for annotation in annotations: - document = Document( - page_content=annotation.question_text, - metadata={"annotation_id": annotation.id, 
"app_id": app.id, "doc_id": annotation.id}, - ) - documents.append(document) - - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - click.echo(f"Migrating annotations for app: {app.id}.") - - try: - vector.delete() - click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green")) - except Exception as e: - click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) - raise e - if documents: - try: - click.echo( - click.style( - f"Creating vector index with {len(documents)} annotations for app {app.id}.", - fg="green", - ) - ) - vector.create(documents) - click.echo(click.style(f"Created vector index for app {app.id}.", fg="green")) - except Exception as e: - click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) - raise e - click.echo(f"Successfully migrated app annotation {app.id}.") - create_count += 1 - except Exception as e: - click.echo( - click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red") - ) - continue - - click.echo( - click.style( - f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.", - fg="green", - ) - ) - - -def migrate_knowledge_vector_database(): - """ - Migrate vector database datas to target vector database . 
- """ - click.echo(click.style("Starting vector database migration.", fg="green")) - create_count = 0 - skipped_count = 0 - total_count = 0 - vector_type = dify_config.VECTOR_STORE - upper_collection_vector_types = { - VectorType.MILVUS, - VectorType.PGVECTOR, - VectorType.VASTBASE, - VectorType.RELYT, - VectorType.WEAVIATE, - VectorType.ORACLE, - VectorType.ELASTICSEARCH, - VectorType.OPENGAUSS, - VectorType.TABLESTORE, - VectorType.MATRIXONE, - } - lower_collection_vector_types = { - VectorType.ANALYTICDB, - VectorType.CHROMA, - VectorType.MYSCALE, - VectorType.PGVECTO_RS, - VectorType.TIDB_VECTOR, - VectorType.OPENSEARCH, - VectorType.TENCENT, - VectorType.BAIDU, - VectorType.VIKINGDB, - VectorType.UPSTASH, - VectorType.COUCHBASE, - VectorType.OCEANBASE, - } - page = 1 - while True: - try: - stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) - ) - - datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - if not datasets.items: - break - except SQLAlchemyError: - raise - - page += 1 - for dataset in datasets: - total_count = total_count + 1 - click.echo( - f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." 
- ) - try: - click.echo(f"Creating dataset vector database index: {dataset.id}") - if dataset.index_struct_dict: - if dataset.index_struct_dict["type"] == vector_type: - skipped_count = skipped_count + 1 - continue - collection_name = "" - dataset_id = dataset.id - if vector_type in upper_collection_vector_types: - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - elif vector_type == VectorType.QDRANT: - if dataset.collection_binding_id: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == dataset.collection_binding_id) - .one_or_none() - ) - if dataset_collection_binding: - collection_name = dataset_collection_binding.collection_name - else: - raise ValueError("Dataset Collection Binding not found") - else: - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - - elif vector_type in lower_collection_vector_types: - collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - else: - raise ValueError(f"Vector store {vector_type} is not supported.") - - index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - vector = Vector(dataset) - click.echo(f"Migrating dataset {dataset.id}.") - - try: - vector.delete() - click.echo( - click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green") - ) - except Exception as e: - click.echo( - click.style( - f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" - ) - ) - raise e - - dataset_documents = db.session.scalars( - select(DatasetDocument).where( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == "completed", - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - ).all() - - documents = [] - segments_count = 0 - for dataset_document in dataset_documents: - segments = db.session.scalars( - 
select(DocumentSegment).where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - ) - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - if dataset_document.doc_form == "hierarchical_model": - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - child_documents.append(child_document) - document.children = child_documents - - documents.append(document) - segments_count = segments_count + 1 - - if documents: - try: - click.echo( - click.style( - f"Creating vector index with {len(documents)} documents of {segments_count}" - f" segments for dataset {dataset.id}.", - fg="green", - ) - ) - all_child_documents = [] - for doc in documents: - if doc.children: - all_child_documents.extend(doc.children) - vector.create(documents) - if all_child_documents: - vector.create(all_child_documents) - click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green")) - except Exception as e: - click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) - raise e - db.session.add(dataset) - db.session.commit() - click.echo(f"Successfully migrated dataset {dataset.id}.") - create_count += 1 - except Exception as e: - db.session.rollback() - click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red")) - continue - - click.echo( - click.style( - f"Migration complete. 
Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green" - ) - ) - - -@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") -def convert_to_agent_apps(): - """ - Convert Agent Assistant to Agent App. - """ - click.echo(click.style("Starting convert to agent apps.", fg="green")) - - proceeded_app_ids = [] - - while True: - # fetch first 1000 apps - sql_query = """SELECT a.id AS id FROM apps a - INNER JOIN app_model_configs am ON a.app_model_config_id=am.id - WHERE a.mode = 'chat' - AND am.agent_mode is not null - AND ( - am.agent_mode like '%"strategy": "function_call"%' - OR am.agent_mode like '%"strategy": "react"%' - ) - AND ( - am.agent_mode like '{"enabled": true%' - OR am.agent_mode like '{"max_iteration": %' - ) ORDER BY a.created_at DESC LIMIT 1000 - """ - - with db.engine.begin() as conn: - rs = conn.execute(sa.text(sql_query)) - - apps = [] - for i in rs: - app_id = str(i.id) - if app_id not in proceeded_app_ids: - proceeded_app_ids.append(app_id) - app = db.session.query(App).where(App.id == app_id).first() - if app is not None: - apps.append(app) - - if len(apps) == 0: - break - - for app in apps: - click.echo(f"Converting app: {app.id}") - - try: - app.mode = AppMode.AGENT_CHAT - db.session.commit() - - # update conversation mode to agent - db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT} - ) - - db.session.commit() - click.echo(click.style(f"Converted app: {app.id}", fg="green")) - except Exception as e: - click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red")) - - click.echo(click.style(f"Conversion complete. 
Converted {len(proceeded_app_ids)} agent apps.", fg="green")) - - -@click.command("add-qdrant-index", help="Add Qdrant index.") -@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.") -def add_qdrant_index(field: str): - click.echo(click.style("Starting Qdrant index creation.", fg="green")) - - create_count = 0 - - try: - bindings = db.session.query(DatasetCollectionBinding).all() - if not bindings: - click.echo(click.style("No dataset collection bindings found.", fg="red")) - return - import qdrant_client - from qdrant_client.http.exceptions import UnexpectedResponse - from qdrant_client.http.models import PayloadSchemaType - - from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig - - for binding in bindings: - if dify_config.QDRANT_URL is None: - raise ValueError("Qdrant URL is required.") - qdrant_config = QdrantConfig( - endpoint=dify_config.QDRANT_URL, - api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.root_path, - timeout=dify_config.QDRANT_CLIENT_TIMEOUT, - grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, - ) - try: - params = qdrant_config.to_qdrant_params() - # Check the type before using - if isinstance(params, PathQdrantParams): - # PathQdrantParams case - client = qdrant_client.QdrantClient(path=params.path) - else: - # UrlQdrantParams case - params is UrlQdrantParams - client = qdrant_client.QdrantClient( - url=params.url, - api_key=params.api_key, - timeout=int(params.timeout), - verify=params.verify, - grpc_port=params.grpc_port, - prefer_grpc=params.prefer_grpc, - ) - # create payload index - client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) - create_count += 1 - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red")) - continue 
- # Some other error occurred, so re-raise the exception - else: - click.echo( - click.style( - f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red" - ) - ) - - except Exception: - click.echo(click.style("Failed to create Qdrant client.", fg="red")) - - click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) - - -@click.command("old-metadata-migration", help="Old metadata migration.") -def old_metadata_migration(): - """ - Old metadata migration. - """ - click.echo(click.style("Starting old metadata migration.", fg="green")) - - page = 1 - while True: - try: - stmt = ( - select(DatasetDocument) - .where(DatasetDocument.doc_metadata.is_not(None)) - .order_by(DatasetDocument.created_at.desc()) - ) - documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except SQLAlchemyError: - raise - if not documents: - break - for document in documents: - if document.doc_metadata: - doc_metadata = document.doc_metadata - for key in doc_metadata: - for field in BuiltInField: - if field.value == key: - break - else: - dataset_metadata = ( - db.session.query(DatasetMetadata) - .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) - .first() - ) - if not dataset_metadata: - dataset_metadata = DatasetMetadata( - tenant_id=document.tenant_id, - dataset_id=document.dataset_id, - name=key, - type="string", - created_by=document.created_by, - ) - db.session.add(dataset_metadata) - db.session.flush() - dataset_metadata_binding = DatasetMetadataBinding( - tenant_id=document.tenant_id, - dataset_id=document.dataset_id, - metadata_id=dataset_metadata.id, - document_id=document.id, - created_by=document.created_by, - ) - db.session.add(dataset_metadata_binding) - else: - dataset_metadata_binding = ( - db.session.query(DatasetMetadataBinding) # type: ignore - .where( - DatasetMetadataBinding.dataset_id == document.dataset_id, - 
DatasetMetadataBinding.document_id == document.id, - DatasetMetadataBinding.metadata_id == dataset_metadata.id, - ) - .first() - ) - if not dataset_metadata_binding: - dataset_metadata_binding = DatasetMetadataBinding( - tenant_id=document.tenant_id, - dataset_id=document.dataset_id, - metadata_id=dataset_metadata.id, - document_id=document.id, - created_by=document.created_by, - ) - db.session.add(dataset_metadata_binding) - db.session.commit() - page += 1 - click.echo(click.style("Old metadata migration completed.", fg="green")) - - -@click.command("create-tenant", help="Create account and tenant.") -@click.option("--email", prompt=True, help="Tenant account email.") -@click.option("--name", prompt=True, help="Workspace name.") -@click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: str | None = None, name: str | None = None): - """ - Create tenant account - """ - if not email: - click.echo(click.style("Email is required.", fg="red")) - return - - # Create account - email = email.strip().lower() - - if "@" not in email: - click.echo(click.style("Invalid email address.", fg="red")) - return - - account_name = email.split("@")[0] - - if language not in languages: - language = "en-US" - - # Validates name encoding for non-Latin characters. 
- name = name.strip().encode("utf-8").decode("utf-8") if name else None - - # generate random password - new_password = secrets.token_urlsafe(16) - - # register account - account = RegisterService.register( - email=email, - name=account_name, - password=new_password, - language=language, - create_workspace_required=False, - ) - TenantService.create_owner_tenant_if_not_exist(account, name) - - click.echo( - click.style( - f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}", - fg="green", - ) - ) - - -@click.command("upgrade-db", help="Upgrade the database") -def upgrade_db(): - click.echo("Preparing database migration...") - lock = DbMigrationAutoRenewLock( - redis_client=redis_client, - name="db_upgrade_lock", - ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS, - logger=logger, - log_context="db_migration", - ) - if lock.acquire(blocking=False): - migration_succeeded = False - try: - click.echo(click.style("Starting database migration.", fg="green")) - - # run db migration - import flask_migrate - - flask_migrate.upgrade() - - migration_succeeded = True - click.echo(click.style("Database migration successful!", fg="green")) - - except Exception as e: - logger.exception("Failed to execute database migration") - click.echo(click.style(f"Database migration failed: {e}", fg="red")) - raise SystemExit(1) - finally: - status = "successful" if migration_succeeded else "failed" - lock.release_safely(status=status) - else: - click.echo("Database migration skipped") - - -@click.command("fix-app-site-missing", help="Fix app related site missing issue.") -def fix_app_site_missing(): - """ - Fix app related site missing issue. 
- """ - click.echo(click.style("Starting fix for missing app-related sites.", fg="green")) - - failed_app_ids = [] - while True: - sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id -where sites.id is null limit 1000""" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(sql)) - - processed_count = 0 - for i in rs: - processed_count += 1 - app_id = str(i.id) - - if app_id in failed_app_ids: - continue - - try: - app = db.session.query(App).where(App.id == app_id).first() - if not app: - logger.info("App %s not found", app_id) - continue - - tenant = app.tenant - if tenant: - accounts = tenant.get_accounts() - if not accounts: - logger.info("Fix failed for app %s", app.id) - continue - - account = accounts[0] - logger.info("Fixing missing site for app %s", app.id) - app_was_created.send(app, account=account) - except Exception: - failed_app_ids.append(app_id) - click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red")) - logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id) - continue - - if not processed_count: - break - - click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) - - -@click.command("migrate-data-for-plugin", help="Migrate data for plugin.") -def migrate_data_for_plugin(): - """ - Migrate data for plugin. - """ - click.echo(click.style("Starting migrate data for plugin.", fg="white")) - - PluginDataMigration.migrate() - - click.echo(click.style("Migrate data for plugin completed.", fg="green")) - - -@click.command("extract-plugins", help="Extract plugins.") -@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl") -@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10) -def extract_plugins(output_file: str, workers: int): - """ - Extract plugins. 
- """ - click.echo(click.style("Starting extract plugins.", fg="white")) - - PluginMigration.extract_plugins(output_file, workers) - - click.echo(click.style("Extract plugins completed.", fg="green")) - - -@click.command("extract-unique-identifiers", help="Extract unique identifiers.") -@click.option( - "--output_file", - prompt=True, - help="The file to store the extracted unique identifiers.", - default="unique_identifiers.json", -) -@click.option( - "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" -) -def extract_unique_plugins(output_file: str, input_file: str): - """ - Extract unique plugins. - """ - click.echo(click.style("Starting extract unique plugins.", fg="white")) - - PluginMigration.extract_unique_plugins_to_file(input_file, output_file) - - click.echo(click.style("Extract unique plugins completed.", fg="green")) - - -@click.command("install-plugins", help="Install plugins.") -@click.option( - "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" -) -@click.option( - "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" -) -@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) -def install_plugins(input_file: str, output_file: str, workers: int): - """ - Install plugins. 
- """ - click.echo(click.style("Starting install plugins.", fg="white")) - - PluginMigration.install_plugins(input_file, output_file, workers) - - click.echo(click.style("Install plugins completed.", fg="green")) - - -@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.") -@click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30) -@click.option("--batch", prompt=True, help="The batch size to clear free plan tenant expired logs.", default=100) -@click.option( - "--tenant_ids", - prompt=True, - multiple=True, - help="The tenant ids to clear free plan tenant expired logs.", -) -def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[str]): - """ - Clear free plan tenant expired logs. - """ - click.echo(click.style("Starting clear free plan tenant expired logs.", fg="white")) - - ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) - - click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) - - -@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") -@click.option( - "--before-days", - "--days", - default=30, - show_default=True, - type=click.IntRange(min=0), - help="Delete workflow runs created before N days ago.", -) -@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") -@click.option( - "--from-days-ago", - default=None, - type=click.IntRange(min=0), - help="Lower bound in days ago (older). Must be paired with --to-days-ago.", -) -@click.option( - "--to-days-ago", - default=None, - type=click.IntRange(min=0), - help="Upper bound in days ago (newer). 
Must be paired with --from-days-ago.", -) -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", -) -@click.option( - "--dry-run", - is_flag=True, - help="Preview cleanup results without deleting any workflow run data.", -) -def clean_workflow_runs( - before_days: int, - batch_size: int, - from_days_ago: int | None, - to_days_ago: int | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - dry_run: bool, -): - """ - Clean workflow runs and related workflow data for free tenants. - """ - if (start_from is None) ^ (end_before is None): - raise click.UsageError("--start-from and --end-before must be provided together.") - - if (from_days_ago is None) ^ (to_days_ago is None): - raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.") - - if from_days_ago is not None and to_days_ago is not None: - if start_from or end_before: - raise click.UsageError("Choose either day offsets or explicit dates, not both.") - if from_days_ago <= to_days_ago: - raise click.UsageError("--from-days-ago must be greater than --to-days-ago.") - now = datetime.datetime.now() - start_from = now - datetime.timedelta(days=from_days_ago) - end_before = now - datetime.timedelta(days=to_days_ago) - before_days = 0 - - start_time = datetime.datetime.now(datetime.UTC) - click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) - - WorkflowRunCleanup( - days=before_days, - batch_size=batch_size, - start_from=start_from, - end_before=end_before, - dry_run=dry_run, - ).run() - - end_time = datetime.datetime.now(datetime.UTC) - elapsed = 
end_time - start_time - click.echo( - click.style( - f"Workflow run cleanup completed. start={start_time.isoformat()} " - f"end={end_time.isoformat()} duration={elapsed}", - fg="green", - ) - ) - - -@click.command( - "archive-workflow-runs", - help="Archive workflow runs for paid plan tenants to S3-compatible storage.", -) -@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") -@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") -@click.option( - "--from-days-ago", - default=None, - type=click.IntRange(min=0), - help="Lower bound in days ago (older). Must be paired with --to-days-ago.", -) -@click.option( - "--to-days-ago", - default=None, - type=click.IntRange(min=0), - help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", -) -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Archive runs created at or after this timestamp (UTC if no timezone).", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Archive runs created before this timestamp (UTC if no timezone).", -) -@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") -@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") -@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") -@click.option("--dry-run", is_flag=True, help="Preview without archiving.") -@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") -def archive_workflow_runs( - tenant_ids: str | None, - before_days: int, - from_days_ago: int | None, - to_days_ago: int | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - batch_size: int, - workers: int, - 
limit: int | None, - dry_run: bool, - delete_after_archive: bool, -): - """ - Archive workflow runs for paid plan tenants older than the specified days. - - This command archives the following tables to storage: - - workflow_node_executions - - workflow_node_execution_offload - - workflow_pauses - - workflow_pause_reasons - - workflow_trigger_logs - - The workflow_runs and workflow_app_logs tables are preserved for UI listing. - """ - from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver - - run_started_at = datetime.datetime.now(datetime.UTC) - click.echo( - click.style( - f"Starting workflow run archiving at {run_started_at.isoformat()}.", - fg="white", - ) - ) - - if (start_from is None) ^ (end_before is None): - click.echo(click.style("start-from and end-before must be provided together.", fg="red")) - return - - if (from_days_ago is None) ^ (to_days_ago is None): - click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) - return - - if from_days_ago is not None and to_days_ago is not None: - if start_from or end_before: - click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) - return - if from_days_ago <= to_days_ago: - click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) - return - now = datetime.datetime.now() - start_from = now - datetime.timedelta(days=from_days_ago) - end_before = now - datetime.timedelta(days=to_days_ago) - before_days = 0 - - if start_from and end_before and start_from >= end_before: - click.echo(click.style("start-from must be earlier than end-before.", fg="red")) - return - if workers < 1: - click.echo(click.style("workers must be at least 1.", fg="red")) - return - - archiver = WorkflowRunArchiver( - days=before_days, - batch_size=batch_size, - start_from=start_from, - end_before=end_before, - workers=workers, - tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else 
None, - limit=limit, - dry_run=dry_run, - delete_after_archive=delete_after_archive, - ) - summary = archiver.run() - click.echo( - click.style( - f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " - f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " - f"time={summary.total_elapsed_time:.2f}s", - fg="cyan", - ) - ) - - run_finished_at = datetime.datetime.now(datetime.UTC) - elapsed = run_finished_at - run_started_at - click.echo( - click.style( - f"Workflow run archiving completed. start={run_started_at.isoformat()} " - f"end={run_finished_at.isoformat()} duration={elapsed}", - fg="green", - ) - ) - - -@click.command( - "restore-workflow-runs", - help="Restore archived workflow runs from S3-compatible storage.", -) -@click.option( - "--tenant-ids", - required=False, - help="Tenant IDs (comma-separated).", -) -@click.option("--run-id", required=False, help="Workflow run ID to restore.") -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", -) -@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") -@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") -@click.option("--dry-run", is_flag=True, help="Preview without restoring.") -def restore_workflow_runs( - tenant_ids: str | None, - run_id: str | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - workers: int, - limit: int, - dry_run: bool, -): - """ - Restore an archived workflow run from storage to the database. 
- - This restores the following tables: - - workflow_node_executions - - workflow_node_execution_offload - - workflow_pauses - - workflow_pause_reasons - - workflow_trigger_logs - """ - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - parsed_tenant_ids = None - if tenant_ids: - parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] - if not parsed_tenant_ids: - raise click.BadParameter("tenant-ids must not be empty") - - if (start_from is None) ^ (end_before is None): - raise click.UsageError("--start-from and --end-before must be provided together.") - if run_id is None and (start_from is None or end_before is None): - raise click.UsageError("--start-from and --end-before are required for batch restore.") - if workers < 1: - raise click.BadParameter("workers must be at least 1") - - start_time = datetime.datetime.now(datetime.UTC) - click.echo( - click.style( - f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", - fg="white", - ) - ) - - restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) - if run_id: - results = [restorer.restore_by_run_id(run_id)] - else: - assert start_from is not None - assert end_before is not None - results = restorer.restore_batch( - parsed_tenant_ids, - start_date=start_from, - end_date=end_before, - limit=limit, - ) - - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time - - successes = sum(1 for result in results if result.success) - failures = len(results) - successes - - if failures == 0: - click.echo( - click.style( - f"Restore completed successfully. success={successes} duration={elapsed}", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Restore completed with failures. 
success={successes} failed={failures} duration={elapsed}", - fg="red", - ) - ) - - -@click.command( - "delete-archived-workflow-runs", - help="Delete archived workflow runs from the database.", -) -@click.option( - "--tenant-ids", - required=False, - help="Tenant IDs (comma-separated).", -) -@click.option("--run-id", required=False, help="Workflow run ID to delete.") -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", -) -@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") -@click.option("--dry-run", is_flag=True, help="Preview without deleting.") -def delete_archived_workflow_runs( - tenant_ids: str | None, - run_id: str | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - limit: int, - dry_run: bool, -): - """ - Delete archived workflow runs from the database. 
- """ - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - parsed_tenant_ids = None - if tenant_ids: - parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] - if not parsed_tenant_ids: - raise click.BadParameter("tenant-ids must not be empty") - - if (start_from is None) ^ (end_before is None): - raise click.UsageError("--start-from and --end-before must be provided together.") - if run_id is None and (start_from is None or end_before is None): - raise click.UsageError("--start-from and --end-before are required for batch delete.") - - start_time = datetime.datetime.now(datetime.UTC) - target_desc = f"workflow run {run_id}" if run_id else "workflow runs" - click.echo( - click.style( - f"Starting delete of {target_desc} at {start_time.isoformat()}.", - fg="white", - ) - ) - - deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) - if run_id: - results = [deleter.delete_by_run_id(run_id)] - else: - assert start_from is not None - assert end_before is not None - results = deleter.delete_batch( - parsed_tenant_ids, - start_date=start_from, - end_date=end_before, - limit=limit, - ) - - for result in results: - if result.success: - click.echo( - click.style( - f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " - f"workflow run {result.run_id} (tenant={result.tenant_id})", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Failed to delete workflow run {result.run_id}: {result.error}", - fg="red", - ) - ) - - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time - - successes = sum(1 for result in results if result.success) - failures = len(results) - successes - - if failures == 0: - click.echo( - click.style( - f"Delete completed successfully. success={successes} duration={elapsed}", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Delete completed with failures. 
success={successes} failed={failures} duration={elapsed}", - fg="red", - ) - ) - - -@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") -@click.command("clear-orphaned-file-records", help="Clear orphaned file records.") -def clear_orphaned_file_records(force: bool): - """ - Clear orphaned file records in the database. - """ - - # define tables and columns to process - files_tables = [ - {"table": "upload_files", "id_column": "id", "key_column": "key"}, - {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, - ] - ids_tables = [ - {"type": "uuid", "table": "message_files", "column": "upload_file_id"}, - {"type": "text", "table": "documents", "column": "data_source_info"}, - {"type": "text", "table": "document_segments", "column": "content"}, - {"type": "text", "table": "messages", "column": "answer"}, - {"type": "text", "table": "workflow_node_executions", "column": "inputs"}, - {"type": "text", "table": "workflow_node_executions", "column": "process_data"}, - {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, - {"type": "text", "table": "conversations", "column": "introduction"}, - {"type": "text", "table": "conversations", "column": "system_instruction"}, - {"type": "text", "table": "accounts", "column": "avatar"}, - {"type": "text", "table": "apps", "column": "icon"}, - {"type": "text", "table": "sites", "column": "icon"}, - {"type": "json", "table": "messages", "column": "inputs"}, - {"type": "json", "table": "messages", "column": "message"}, - ] - - # notify user and ask for confirmation - click.echo( - click.style( - "This command will first find and delete orphaned file records from the message_files table,", fg="yellow" - ) - ) - click.echo( - click.style( - "and then it will find and delete orphaned file records in the following tables:", - fg="yellow", - ) - ) - for files_table in files_tables: - click.echo(click.style(f"- {files_table['table']}", 
fg="yellow")) - click.echo( - click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow") - ) - for ids_table in ids_tables: - click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")) - click.echo("") - - click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) - click.echo( - click.style( - ( - "Since not all patterns have been fully tested, " - "please note that this command may delete unintended file records." - ), - fg="yellow", - ) - ) - click.echo( - click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow") - ) - click.echo( - click.style( - ( - "It is also recommended to run this during the maintenance window, " - "as this may cause high load on your instance." - ), - fg="yellow", - ) - ) - if not force: - click.confirm("Do you want to proceed?", abort=True) - - # start the cleanup process - click.echo(click.style("Starting orphaned file records cleanup.", fg="white")) - - # clean up the orphaned records in the message_files table where message_id doesn't exist in messages table - try: - click.echo( - click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white") - ) - query = ( - "SELECT mf.id, mf.message_id " - "FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id " - "WHERE m.id IS NULL" - ) - orphaned_message_files = [] - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) - - if orphaned_message_files: - click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white")) - for record in orphaned_message_files: - click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black")) - - if not force: - click.confirm( - ( - f"Do you want to proceed " - f"to delete all 
{len(orphaned_message_files)} orphaned message_files records?" - ), - abort=True, - ) - - click.echo(click.style("- Deleting orphaned message_files records", fg="white")) - query = "DELETE FROM message_files WHERE id IN :ids" - with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) - click.echo( - click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") - ) - else: - click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red")) - - # clean up the orphaned records in the rest of the *_files tables - try: - # fetch file id and keys from each table - all_files_in_tables = [] - for files_table in files_tables: - click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) - query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) - click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) - - # fetch referred table and columns - guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" - all_ids_in_tables = [] - for ids_table in ids_tables: - query = "" - match ids_table["type"]: - case "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", - fg="white", - ) - ) - c = ids_table["column"] - query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": 
str(i[0])}) - case "text": - t = ids_table["table"] - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", - fg="white", - ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - case "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", - ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - case _: - pass - click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) - - except Exception as e: - click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) - return - - # find orphaned files - all_files = [file["id"] for file in all_files_in_tables] - all_ids = [file["id"] for file in all_ids_in_tables] - orphaned_files = list(set(all_files) - set(all_ids)) - if not orphaned_files: - click.echo(click.style("No orphaned file records found. 
There is nothing to delete.", fg="green")) - return - click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")) - for file in orphaned_files: - click.echo(click.style(f"- orphaned file id: {file}", fg="black")) - if not force: - click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True) - - # delete orphaned records for each file - try: - for files_table in files_tables: - click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) - query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" - with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) - except Exception as e: - click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) - return - click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")) - - -@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") -@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.") -def remove_orphaned_files_on_storage(force: bool): - """ - Remove orphaned files on the storage. 
- """ - - # define tables and columns to process - files_tables = [ - {"table": "upload_files", "key_column": "key"}, - {"table": "tool_files", "key_column": "file_key"}, - ] - storage_paths = ["image_files", "tools", "upload_files"] - - # notify user and ask for confirmation - click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow")) - click.echo( - click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow") - ) - for files_table in files_tables: - click.echo(click.style(f"- {files_table['table']}", fg="yellow")) - click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow")) - for storage_path in storage_paths: - click.echo(click.style(f"- {storage_path}", fg="yellow")) - click.echo("") - - click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) - click.echo( - click.style( - "Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow" - ) - ) - click.echo( - click.style( - "Since not all patterns have been fully tested, please note that this command may delete unintended files.", - fg="yellow", - ) - ) - click.echo( - click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow") - ) - click.echo( - click.style( - ( - "It is also recommended to run this during the maintenance window, " - "as this may cause high load on your instance." 
- ), - fg="yellow", - ) - ) - if not force: - click.confirm("Do you want to proceed?", abort=True) - - # start the cleanup process - click.echo(click.style("Starting orphaned files cleanup.", fg="white")) - - # fetch file id and keys from each table - all_files_in_tables = [] - try: - for files_table in files_tables: - click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) - query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_files_in_tables.append(str(i[0])) - click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) - except Exception as e: - click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) - return - - all_files_on_storage = [] - for storage_path in storage_paths: - try: - click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white")) - files = storage.scan(path=storage_path, files=True, directories=False) - all_files_on_storage.extend(files) - except FileNotFoundError as e: - click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow")) - continue - except Exception as e: - click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red")) - continue - click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")) - - # find orphaned files - orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) - if not orphaned_files: - click.echo(click.style("No orphaned files found. 
There is nothing to remove.", fg="green")) - return - click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) - for file in orphaned_files: - click.echo(click.style(f"- orphaned file: {file}", fg="black")) - if not force: - click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True) - - # delete orphaned files - removed_files = 0 - error_files = 0 - for file in orphaned_files: - try: - storage.delete(file) - removed_files += 1 - click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) - except Exception as e: - error_files += 1 - click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red")) - continue - if error_files == 0: - click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) - else: - click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) - - -@click.command("file-usage", help="Query file usages and show where files are referenced.") -@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") -@click.option("--key", type=str, default=None, help="Filter by storage key.") -@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") -@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") -@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") -@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") -def file_usage( - file_id: str | None, - key: str | None, - src: str | None, - limit: int, - offset: int, - output_json: bool, -): - """ - Query file usages and show where files are referenced in the database. - - This command reuses the same reference checking logic as clear-orphaned-file-records - and displays detailed information about where each file is referenced. 
- """ - # define tables and columns to process - files_tables = [ - {"table": "upload_files", "id_column": "id", "key_column": "key"}, - {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, - ] - ids_tables = [ - {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, - {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, - {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, - {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, - {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, - {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, - {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, - {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, - {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, - {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, - {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, - {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, - {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, - {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, - ] - - # Stream file usages with pagination to avoid holding all results in memory - paginated_usages = [] - total_count = 0 - - # First, build a mapping of file_id -> storage_key from the base tables - file_key_map = {} - for files_table in files_tables: - query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" - - # If filtering by key or 
file_id, verify it exists - if file_id and file_id not in file_key_map: - if output_json: - click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) - else: - click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) - return - - if key: - valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} - matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] - if not matching_file_ids: - if output_json: - click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) - else: - click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) - return - - guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" - - # For each reference table/column, find matching file IDs and record the references - for ids_table in ids_tables: - src_filter = f"{ids_table['table']}.{ids_table['column']}" - - # Skip if src filter doesn't match (use fnmatch for wildcard patterns) - if src: - if "%" in src or "_" in src: - import fnmatch - - # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) 
- pattern = src.replace("%", "*").replace("_", "?") - if not fnmatch.fnmatch(src_filter, pattern): - continue - else: - if src_filter != src: - continue - - match ids_table["type"]: - case "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - case "text" | "json": - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - 
"src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - case _: - pass - - # Output results - if output_json: - result = { - "total": total_count, - "offset": offset, - "limit": limit, - "usages": paginated_usages, - } - click.echo(json.dumps(result, indent=2)) - else: - click.echo( - click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") - ) - click.echo("") - - if not paginated_usages: - click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) - return - - # Print table header - click.echo( - click.style( - f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", - fg="cyan", - ) - ) - click.echo(click.style("-" * 190, fg="white")) - - # Print each usage - for usage in paginated_usages: - click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") - - # Show pagination info - if offset + limit < total_count: - click.echo("") - click.echo( - click.style( - f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" - ) - ) - click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) - - -@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") -@click.option("--provider", prompt=True, help="Provider name") -@click.option("--client-params", prompt=True, help="Client Params") -def setup_system_tool_oauth_client(provider, client_params): - """ - Setup system tool oauth client - """ - provider_id = ToolProviderID(provider) - provider_name = provider_id.provider_name - plugin_id = provider_id.plugin_id - - try: - # json validate - click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) - click.echo(click.style("Client 
params validated successfully.", fg="green")) - - click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) - click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) - click.echo(click.style("Client params encrypted successfully.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - - deleted_count = ( - db.session.query(ToolOAuthSystemClient) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) - if deleted_count > 0: - click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) - - oauth_client = ToolOAuthSystemClient( - provider=provider_name, - plugin_id=plugin_id, - encrypted_oauth_params=oauth_client_params, - ) - db.session.add(oauth_client) - db.session.commit() - click.echo(click.style(f"OAuth client params setup successfully. 
id: {oauth_client.id}", fg="green")) - - -@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.") -@click.option("--provider", prompt=True, help="Provider name") -@click.option("--client-params", prompt=True, help="Client Params") -def setup_system_trigger_oauth_client(provider, client_params): - """ - Setup system trigger oauth client - """ - from models.provider_ids import TriggerProviderID - from models.trigger import TriggerOAuthSystemClient - - provider_id = TriggerProviderID(provider) - provider_name = provider_id.provider_name - plugin_id = provider_id.plugin_id - - try: - # json validate - click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) - click.echo(click.style("Client params validated successfully.", fg="green")) - - click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) - click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) - click.echo(click.style("Client params encrypted successfully.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - - deleted_count = ( - db.session.query(TriggerOAuthSystemClient) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) - if deleted_count > 0: - click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) - - oauth_client = TriggerOAuthSystemClient( - provider=provider_name, - plugin_id=plugin_id, - encrypted_oauth_params=oauth_client_params, - ) - db.session.add(oauth_client) - db.session.commit() - click.echo(click.style(f"OAuth client params setup successfully. 
id: {oauth_client.id}", fg="green")) - - -def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: - """ - Find draft variables that reference non-existent apps. - - Args: - batch_size: Maximum number of orphaned app IDs to return - - Returns: - List of app IDs that have draft variables but don't exist in the apps table - """ - query = """ - SELECT DISTINCT wdv.app_id - FROM workflow_draft_variables AS wdv - WHERE NOT EXISTS( - SELECT 1 FROM apps WHERE apps.id = wdv.app_id - ) - LIMIT :batch_size - """ - - with db.engine.connect() as conn: - result = conn.execute(sa.text(query), {"batch_size": batch_size}) - return [row[0] for row in result] - - -def _count_orphaned_draft_variables() -> dict[str, Any]: - """ - Count orphaned draft variables by app, including associated file counts. - - Returns: - Dictionary with statistics about orphaned variables and files - """ - # Count orphaned variables by app - variables_query = """ - SELECT - wdv.app_id, - COUNT(*) as variable_count, - COUNT(wdv.file_id) as file_count - FROM workflow_draft_variables AS wdv - WHERE NOT EXISTS( - SELECT 1 FROM apps WHERE apps.id = wdv.app_id - ) - GROUP BY wdv.app_id - ORDER BY variable_count DESC - """ - - with db.engine.connect() as conn: - result = conn.execute(sa.text(variables_query)) - orphaned_by_app = {} - total_files = 0 - - for row in result: - app_id, variable_count, file_count = row - orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count} - total_files += file_count - - total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values()) - app_count = len(orphaned_by_app) - - return { - "total_orphaned_variables": total_orphaned, - "total_orphaned_files": total_files, - "orphaned_app_count": app_count, - "orphaned_by_app": orphaned_by_app, - } - - -@click.command() -@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting") -@click.option("--batch-size", default=1000, help="Number of 
records to process per batch (default 1000)") -@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)") -@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") -def cleanup_orphaned_draft_variables( - dry_run: bool, - batch_size: int, - max_apps: int | None, - force: bool = False, -): - """ - Clean up orphaned draft variables from the database. - - This script finds and removes draft variables that belong to apps - that no longer exist in the database. - """ - logger = logging.getLogger(__name__) - - # Get statistics - stats = _count_orphaned_draft_variables() - - logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"]) - logger.info("Found %s associated offload files", stats["total_orphaned_files"]) - logger.info("Across %s non-existent apps", stats["orphaned_app_count"]) - - if stats["total_orphaned_variables"] == 0: - logger.info("No orphaned draft variables found. Exiting.") - return - - if dry_run: - logger.info("DRY RUN: Would delete the following:") - for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[ - :10 - ]: # Show top 10 - logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"]) - if len(stats["orphaned_by_app"]) > 10: - logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10) - return - - # Confirm deletion - if not force: - click.confirm( - f"Are you sure you want to delete {stats['total_orphaned_variables']} " - f"orphaned draft variables and {stats['total_orphaned_files']} associated files " - f"from {stats['orphaned_app_count']} apps?", - abort=True, - ) - - total_deleted = 0 - processed_apps = 0 - - while True: - if max_apps and processed_apps >= max_apps: - logger.info("Reached maximum app limit (%s). 
Stopping.", max_apps) - break - - orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10) - if not orphaned_app_ids: - logger.info("No more orphaned draft variables found.") - break - - for app_id in orphaned_app_ids: - if max_apps and processed_apps >= max_apps: - break - - try: - deleted_count = delete_draft_variables_batch(app_id, batch_size) - total_deleted += deleted_count - processed_apps += 1 - - logger.info("Deleted %s variables for app %s", deleted_count, app_id) - - except Exception: - logger.exception("Error processing app %s", app_id) - continue - - logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) - - -@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.") -@click.option("--provider", prompt=True, help="Provider name") -@click.option("--client-params", prompt=True, help="Client Params") -def setup_datasource_oauth_client(provider, client_params): - """ - Setup datasource oauth client - """ - provider_id = DatasourceProviderID(provider) - provider_name = provider_id.provider_name - plugin_id = provider_id.plugin_id - - try: - # json validate - click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) - click.echo(click.style("Client params validated successfully.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - - click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) - deleted_count = ( - db.session.query(DatasourceOauthParamConfig) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) - if deleted_count > 0: - click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) - - click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", 
fg="yellow")) - oauth_client = DatasourceOauthParamConfig( - provider=provider_name, - plugin_id=plugin_id, - system_credentials=client_params_dict, - ) - db.session.add(oauth_client) - db.session.commit() - click.echo(click.style(f"provider: {provider_name}", fg="green")) - click.echo(click.style(f"plugin_id: {plugin_id}", fg="green")) - click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green")) - click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green")) - - -@click.command("transform-datasource-credentials", help="Transform datasource credentials.") -@click.option( - "--environment", prompt=True, help="the environment to transform datasource credentials", default="online" -) -def transform_datasource_credentials(environment: str): - """ - Transform datasource credentials - """ - try: - installer_manager = PluginInstaller() - plugin_migration = PluginMigration() - - notion_plugin_id = "langgenius/notion_datasource" - firecrawl_plugin_id = "langgenius/firecrawl_datasource" - jina_plugin_id = "langgenius/jina_datasource" - if environment == "online": - notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage] - firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage] - jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage] - else: - notion_plugin_unique_identifier = None - firecrawl_plugin_unique_identifier = None - jina_plugin_unique_identifier = None - oauth_credential_type = CredentialType.OAUTH2 - api_key_credential_type = CredentialType.API_KEY - - # deal notion credentials - deal_notion_count = 0 - notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all() - if notion_credentials: 
- notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} - for notion_credential in notion_credentials: - tenant_id = notion_credential.tenant_id - if tenant_id not in notion_credentials_tenant_mapping: - notion_credentials_tenant_mapping[tenant_id] = [] - notion_credentials_tenant_mapping[tenant_id].append(notion_credential) - for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() - if not tenant: - continue - try: - # check notion plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if notion_plugin_id not in installed_plugins_ids: - if notion_plugin_unique_identifier: - # install notion plugin - PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) - auth_count = 0 - for notion_tenant_credential in notion_tenant_credentials: - auth_count += 1 - # get credential oauth params - access_token = notion_tenant_credential.access_token - # notion info - notion_info = notion_tenant_credential.source_info - workspace_id = notion_info.get("workspace_id") - workspace_name = notion_info.get("workspace_name") - workspace_icon = notion_info.get("workspace_icon") - new_credentials = { - "integration_secret": encrypter.encrypt_token(tenant_id, access_token), - "workspace_id": workspace_id, - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - } - datasource_provider = DatasourceProvider( - provider="notion_datasource", - tenant_id=tenant_id, - plugin_id=notion_plugin_id, - auth_type=oauth_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url=workspace_icon or "default", - is_default=False, - ) - db.session.add(datasource_provider) - deal_notion_count += 1 - except Exception as e: - click.echo( - click.style( - f"Error transforming notion 
credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" - ) - ) - continue - db.session.commit() - # deal firecrawl credentials - deal_firecrawl_count = 0 - firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all() - if firecrawl_credentials: - firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} - for firecrawl_credential in firecrawl_credentials: - tenant_id = firecrawl_credential.tenant_id - if tenant_id not in firecrawl_credentials_tenant_mapping: - firecrawl_credentials_tenant_mapping[tenant_id] = [] - firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) - for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() - if not tenant: - continue - try: - # check firecrawl plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if firecrawl_plugin_id not in installed_plugins_ids: - if firecrawl_plugin_unique_identifier: - # install firecrawl plugin - PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) - - auth_count = 0 - for firecrawl_tenant_credential in firecrawl_tenant_credentials: - auth_count += 1 - if not firecrawl_tenant_credential.credentials: - click.echo( - click.style( - f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.", - fg="yellow", - ) - ) - continue - # get credential api key - credentials_json = json.loads(firecrawl_tenant_credential.credentials) - api_key = credentials_json.get("config", {}).get("api_key") - base_url = credentials_json.get("config", {}).get("base_url") - new_credentials = { - "firecrawl_api_key": api_key, - "base_url": base_url, - } - datasource_provider = DatasourceProvider( - provider="firecrawl", - tenant_id=tenant_id, - 
plugin_id=firecrawl_plugin_id, - auth_type=api_key_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url="default", - is_default=False, - ) - db.session.add(datasource_provider) - deal_firecrawl_count += 1 - except Exception as e: - click.echo( - click.style( - f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" - ) - ) - continue - db.session.commit() - # deal jina credentials - deal_jina_count = 0 - jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all() - if jina_credentials: - jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} - for jina_credential in jina_credentials: - tenant_id = jina_credential.tenant_id - if tenant_id not in jina_credentials_tenant_mapping: - jina_credentials_tenant_mapping[tenant_id] = [] - jina_credentials_tenant_mapping[tenant_id].append(jina_credential) - for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() - if not tenant: - continue - try: - # check jina plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if jina_plugin_id not in installed_plugins_ids: - if jina_plugin_unique_identifier: - # install jina plugin - logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) - PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) - - auth_count = 0 - for jina_tenant_credential in jina_tenant_credentials: - auth_count += 1 - if not jina_tenant_credential.credentials: - click.echo( - click.style( - f"Skipping jina credential for tenant {tenant_id} due to missing credentials.", - fg="yellow", - ) - ) - continue - # get credential api key - credentials_json = json.loads(jina_tenant_credential.credentials) - api_key = 
credentials_json.get("config", {}).get("api_key") - new_credentials = { - "integration_secret": api_key, - } - datasource_provider = DatasourceProvider( - provider="jinareader", - tenant_id=tenant_id, - plugin_id=jina_plugin_id, - auth_type=api_key_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url="default", - is_default=False, - ) - db.session.add(datasource_provider) - deal_jina_count += 1 - except Exception as e: - click.echo( - click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red") - ) - continue - db.session.commit() - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green")) - click.echo( - click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green") - ) - click.echo(click.style(f"Transforming jina successfully. 
deal_jina_count: {deal_jina_count}", fg="green")) - - -@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.") -@click.option( - "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" -) -@click.option( - "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" -) -@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) -def install_rag_pipeline_plugins(input_file, output_file, workers): - """ - Install rag pipeline plugins - """ - click.echo(click.style("Installing rag pipeline plugins", fg="yellow")) - plugin_migration = PluginMigration() - plugin_migration.install_rag_pipeline_plugins( - input_file, - output_file, - workers, - ) - click.echo(click.style("Installing rag pipeline plugins successfully", fg="green")) - - -@click.command( - "migrate-oss", - help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).", -) -@click.option( - "--path", - "paths", - multiple=True, - help="Storage path prefixes to migrate (repeatable). 
Defaults: privkeys, upload_files, image_files," - " tools, website_files, keyword_files, ops_trace", -) -@click.option( - "--source", - type=click.Choice(["local", "opendal"], case_sensitive=False), - default="opendal", - show_default=True, - help="Source storage type to read from", -) -@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists") -@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading") -@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts") -@click.option( - "--update-db/--no-update-db", - default=True, - help="Update upload_files.storage_type from source type to current storage after migration", -) -def migrate_oss( - paths: tuple[str, ...], - source: str, - overwrite: bool, - dry_run: bool, - force: bool, - update_db: bool, -): - """ - Copy all files under selected prefixes from a source storage - (Local filesystem or OpenDAL-backed) into the currently configured - destination storage backend, then optionally update DB records. - - Expected usage: set STORAGE_TYPE (and its credentials) to your target backend. 
- """ - # Ensure target storage is not local/opendal - if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL): - click.echo( - click.style( - "Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n" - "Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n" - "volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.", - fg="red", - ) - ) - return - - # Default paths if none specified - default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace") - path_list = list(paths) if paths else list(default_paths) - is_source_local = source.lower() == "local" - - click.echo(click.style("Preparing migration to target storage.", fg="yellow")) - click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white")) - if is_source_local: - src_root = dify_config.STORAGE_LOCAL_PATH - click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white")) - else: - click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white")) - click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white")) - click.echo("") - - if not force: - click.confirm("Proceed with migration?", abort=True) - - # Instantiate source storage - try: - if is_source_local: - src_root = dify_config.STORAGE_LOCAL_PATH - source_storage = OpenDALStorage(scheme="fs", root=src_root) - else: - source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME) - except Exception as e: - click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red")) - return - - total_files = 0 - copied_files = 0 - skipped_files = 0 - errored_files = 0 - copied_upload_file_keys: list[str] = [] - - for prefix in path_list: - click.echo(click.style(f"Scanning source path: {prefix}", fg="white")) - try: - keys = source_storage.scan(path=prefix, files=True, directories=False) - except FileNotFoundError: - 
click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow")) - continue - except NotImplementedError: - click.echo(click.style(" -> Source storage does not support scanning.", fg="red")) - return - except Exception as e: - click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red")) - continue - - click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white")) - - for key in keys: - total_files += 1 - - # check destination existence - if not overwrite: - try: - if storage.exists(key): - skipped_files += 1 - continue - except Exception as e: - # existence check failures should not block migration attempt - # but should be surfaced to user as a warning for visibility - click.echo( - click.style( - f" -> Warning: failed target existence check for {key}: {str(e)}", - fg="yellow", - ) - ) - - if dry_run: - copied_files += 1 - continue - - # read from source and write to destination - try: - data = source_storage.load_once(key) - except FileNotFoundError: - errored_files += 1 - click.echo(click.style(f" -> Missing on source: {key}", fg="yellow")) - continue - except Exception as e: - errored_files += 1 - click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red")) - continue - - try: - storage.save(key, data) - copied_files += 1 - if prefix == "upload_files": - copied_upload_file_keys.append(key) - except Exception as e: - errored_files += 1 - click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red")) - continue - - click.echo("") - click.echo(click.style("Migration summary:", fg="yellow")) - click.echo(click.style(f" Total: {total_files}", fg="white")) - click.echo(click.style(f" Copied: {copied_files}", fg="green")) - click.echo(click.style(f" Skipped: {skipped_files}", fg="white")) - if errored_files: - click.echo(click.style(f" Errors: {errored_files}", fg="red")) - - if dry_run: - click.echo(click.style("Dry-run complete. 
No changes were made.", fg="green")) - return - - if errored_files: - click.echo( - click.style( - "Some files failed to migrate. Review errors above before updating DB records.", - fg="yellow", - ) - ) - if update_db and not force: - if not click.confirm("Proceed to update DB storage_type despite errors?", default=False): - update_db = False - - # Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files) - if update_db: - if not copied_upload_file_keys: - click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow")) - else: - try: - source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL - updated = ( - db.session.query(UploadFile) - .where( - UploadFile.storage_type == source_storage_type, - UploadFile.key.in_(copied_upload_file_keys), - ) - .update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False) - ) - db.session.commit() - click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green")) - except Exception as e: - db.session.rollback() - click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) - - -@click.command("clean-expired-messages", help="Clean expired messages.") -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - required=True, - help="Lower bound (inclusive) for created_at.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - required=True, - help="Upper bound (exclusive) for created_at.", -) -@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") -@click.option( - "--graceful-period", - default=21, - show_default=True, - help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.", -) -@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs 
would be cleaned without deleting") -def clean_expired_messages( - batch_size: int, - graceful_period: int, - start_from: datetime.datetime, - end_before: datetime.datetime, - dry_run: bool, -): - """ - Clean expired messages and related data for tenants based on clean policy. - """ - click.echo(click.style("clean_messages: start clean messages.", fg="green")) - - start_at = time.perf_counter() - - try: - # Create policy based on billing configuration - # NOTE: graceful_period will be ignored when billing is disabled. - policy = create_message_clean_policy(graceful_period_days=graceful_period) - - # Create and run the cleanup service - service = MessagesCleanService.from_time_range( - policy=policy, - start_from=start_from, - end_before=end_before, - batch_size=batch_size, - dry_run=dry_run, - ) - stats = service.run() - - end_at = time.perf_counter() - click.echo( - click.style( - f"clean_messages: completed successfully\n" - f" - Latency: {end_at - start_at:.2f}s\n" - f" - Batches processed: {stats['batches']}\n" - f" - Total messages scanned: {stats['total_messages']}\n" - f" - Messages filtered: {stats['filtered_messages']}\n" - f" - Messages deleted: {stats['total_deleted']}", - fg="green", - ) - ) - except Exception as e: - end_at = time.perf_counter() - logger.exception("clean_messages failed") - click.echo( - click.style( - f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", - fg="red", - ) - ) - raise - - click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/commands/__init__.py b/api/commands/__init__.py new file mode 100644 index 0000000000..d62d0dbd7c --- /dev/null +++ b/api/commands/__init__.py @@ -0,0 +1,71 @@ +""" +CLI command modules extracted from `commands.py`. 
+""" + +from .account import create_tenant, reset_email, reset_password +from .plugin import ( + extract_plugins, + extract_unique_plugins, + install_plugins, + install_rag_pipeline_plugins, + migrate_data_for_plugin, + setup_datasource_oauth_client, + setup_system_tool_oauth_client, + setup_system_trigger_oauth_client, + transform_datasource_credentials, +) +from .retention import ( + archive_workflow_runs, + clean_expired_messages, + clean_workflow_runs, + cleanup_orphaned_draft_variables, + clear_free_plan_tenant_expired_logs, + delete_archived_workflow_runs, + export_app_messages, + restore_workflow_runs, +) +from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage +from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db +from .vector import ( + add_qdrant_index, + migrate_annotation_vector_database, + migrate_knowledge_vector_database, + old_metadata_migration, + vdb_migrate, +) + +__all__ = [ + "add_qdrant_index", + "archive_workflow_runs", + "clean_expired_messages", + "clean_workflow_runs", + "cleanup_orphaned_draft_variables", + "clear_free_plan_tenant_expired_logs", + "clear_orphaned_file_records", + "convert_to_agent_apps", + "create_tenant", + "delete_archived_workflow_runs", + "export_app_messages", + "extract_plugins", + "extract_unique_plugins", + "file_usage", + "fix_app_site_missing", + "install_plugins", + "install_rag_pipeline_plugins", + "migrate_annotation_vector_database", + "migrate_data_for_plugin", + "migrate_knowledge_vector_database", + "migrate_oss", + "old_metadata_migration", + "remove_orphaned_files_on_storage", + "reset_email", + "reset_encrypt_key_pair", + "reset_password", + "restore_workflow_runs", + "setup_datasource_oauth_client", + "setup_system_tool_oauth_client", + "setup_system_trigger_oauth_client", + "transform_datasource_credentials", + "upgrade_db", + "vdb_migrate", +] diff --git a/api/commands/account.py 
b/api/commands/account.py new file mode 100644 index 0000000000..84af7a5ae6 --- /dev/null +++ b/api/commands/account.py @@ -0,0 +1,130 @@ +import base64 +import secrets + +import click +from sqlalchemy.orm import sessionmaker + +from constants.languages import languages +from extensions.ext_database import db +from libs.helper import email as email_validate +from libs.password import hash_password, password_pattern, valid_password +from services.account_service import AccountService, RegisterService, TenantService + + +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, help="Account email to reset password for") +@click.option("--new-password", prompt=True, help="New password") +@click.option("--password-confirm", prompt=True, help="Confirm new password") +def reset_password(email, new_password, password_confirm): + """ + Reset password of owner account + Only available in SELF_HOSTED mode + """ + if str(new_password).strip() != str(password_confirm).strip(): + click.echo(click.style("Passwords do not match.", fg="red")) + return + normalized_email = email.strip().lower() + + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return + + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. 
Must match {password_pattern}", fg="red")) + return + + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) + + +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="Current account email") +@click.option("--new-email", prompt=True, help="New email") +@click.option("--email-confirm", prompt=True, help="Confirm new email") +def reset_email(email, new_email, email_confirm): + """ + Replace account email + :return: + """ + if str(new_email).strip() != str(email_confirm).strip(): + click.echo(click.style("New emails do not match.", fg="red")) + return + normalized_new_email = new_email.strip().lower() + + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return + + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return + + account.email = normalized_new_email + click.echo(click.style("Email updated successfully.", fg="green")) + + +@click.command("create-tenant", help="Create account and tenant.") +@click.option("--email", prompt=True, help="Tenant account email.") +@click.option("--name", prompt=True, help="Workspace name.") +@click.option("--language", prompt=True, help="Account language, default: en-US.") +def create_tenant(email: str, language: str | None = None, name: str | None = None): + """ + 
Create tenant account + """ + if not email: + click.echo(click.style("Email is required.", fg="red")) + return + + # Create account + email = email.strip().lower() + + if "@" not in email: + click.echo(click.style("Invalid email address.", fg="red")) + return + + account_name = email.split("@")[0] + + if language not in languages: + language = "en-US" + + # Validates name encoding for non-Latin characters. + name = name.strip().encode("utf-8").decode("utf-8") if name else None + + # generate random password + new_password = secrets.token_urlsafe(16) + + # register account + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language, + create_workspace_required=False, + ) + TenantService.create_owner_tenant_if_not_exist(account, name) + + click.echo( + click.style( + f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}", + fg="green", + ) + ) diff --git a/api/commands/plugin.py b/api/commands/plugin.py new file mode 100644 index 0000000000..2dfbd73b3a --- /dev/null +++ b/api/commands/plugin.py @@ -0,0 +1,467 @@ +import json +import logging +from typing import Any + +import click +from pydantic import TypeAdapter + +from configs import dify_config +from core.helper import encrypter +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.plugin import PluginInstaller +from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params +from extensions.ext_database import db +from models import Tenant +from models.oauth import DatasourceOauthParamConfig, DatasourceProvider +from models.provider_ids import DatasourceProviderID, ToolProviderID +from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from models.tools import ToolOAuthSystemClient +from services.plugin.data_migration import PluginDataMigration +from services.plugin.plugin_migration import PluginMigration +from services.plugin.plugin_service import PluginService + 
+logger = logging.getLogger(__name__) + + +@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_tool_oauth_client(provider, client_params): + """ + Setup system tool oauth client + """ + provider_id = ToolProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = ( + db.session.query(ToolOAuthSystemClient) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = ToolOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. 
id: {oauth_client.id}", fg="green")) + + +@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_trigger_oauth_client(provider, client_params): + """ + Setup system trigger oauth client + """ + from models.provider_ids import TriggerProviderID + from models.trigger import TriggerOAuthSystemClient + + provider_id = TriggerProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = ( + db.session.query(TriggerOAuthSystemClient) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = TriggerOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. 
id: {oauth_client.id}", fg="green")) + + +@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_datasource_oauth_client(provider, client_params): + """ + Setup datasource oauth client + """ + provider_id = DatasourceProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) + deleted_count = ( + db.session.query(DatasourceOauthParamConfig) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow")) + oauth_client = DatasourceOauthParamConfig( + provider=provider_name, + plugin_id=plugin_id, + system_credentials=client_params_dict, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"provider: {provider_name}", fg="green")) + click.echo(click.style(f"plugin_id: {plugin_id}", fg="green")) + click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green")) + click.echo(click.style(f"Datasource oauth client setup successfully. 
id: {oauth_client.id}", fg="green")) + + +@click.command("transform-datasource-credentials", help="Transform datasource credentials.") +@click.option( + "--environment", prompt=True, help="the environment to transform datasource credentials", default="online" +) +def transform_datasource_credentials(environment: str): + """ + Transform datasource credentials + """ + try: + installer_manager = PluginInstaller() + plugin_migration = PluginMigration() + + notion_plugin_id = "langgenius/notion_datasource" + firecrawl_plugin_id = "langgenius/firecrawl_datasource" + jina_plugin_id = "langgenius/jina_datasource" + if environment == "online": + notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage] + firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage] + jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage] + else: + notion_plugin_unique_identifier = None + firecrawl_plugin_unique_identifier = None + jina_plugin_unique_identifier = None + oauth_credential_type = CredentialType.OAUTH2 + api_key_credential_type = CredentialType.API_KEY + + # deal notion credentials + deal_notion_count = 0 + notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all() + if notion_credentials: + notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} + for notion_credential in notion_credentials: + tenant_id = notion_credential.tenant_id + if tenant_id not in notion_credentials_tenant_mapping: + notion_credentials_tenant_mapping[tenant_id] = [] + notion_credentials_tenant_mapping[tenant_id].append(notion_credential) + for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + if not 
tenant: + continue + try: + # check notion plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if notion_plugin_id not in installed_plugins_ids: + if notion_plugin_unique_identifier: + # install notion plugin + PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) + auth_count = 0 + for notion_tenant_credential in notion_tenant_credentials: + auth_count += 1 + # get credential oauth params + access_token = notion_tenant_credential.access_token + # notion info + notion_info = notion_tenant_credential.source_info + workspace_id = notion_info.get("workspace_id") + workspace_name = notion_info.get("workspace_name") + workspace_icon = notion_info.get("workspace_icon") + new_credentials = { + "integration_secret": encrypter.encrypt_token(tenant_id, access_token), + "workspace_id": workspace_id, + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + } + datasource_provider = DatasourceProvider( + provider="notion_datasource", + tenant_id=tenant_id, + plugin_id=notion_plugin_id, + auth_type=oauth_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url=workspace_icon or "default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_notion_count += 1 + except Exception as e: + click.echo( + click.style( + f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" + ) + ) + continue + db.session.commit() + # deal firecrawl credentials + deal_firecrawl_count = 0 + firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all() + if firecrawl_credentials: + firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} + for firecrawl_credential in firecrawl_credentials: + tenant_id = firecrawl_credential.tenant_id + if tenant_id not in 
firecrawl_credentials_tenant_mapping: + firecrawl_credentials_tenant_mapping[tenant_id] = [] + firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) + for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + if not tenant: + continue + try: + # check firecrawl plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if firecrawl_plugin_id not in installed_plugins_ids: + if firecrawl_plugin_unique_identifier: + # install firecrawl plugin + PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) + + auth_count = 0 + for firecrawl_tenant_credential in firecrawl_tenant_credentials: + auth_count += 1 + if not firecrawl_tenant_credential.credentials: + click.echo( + click.style( + f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.", + fg="yellow", + ) + ) + continue + # get credential api key + credentials_json = json.loads(firecrawl_tenant_credential.credentials) + api_key = credentials_json.get("config", {}).get("api_key") + base_url = credentials_json.get("config", {}).get("base_url") + new_credentials = { + "firecrawl_api_key": api_key, + "base_url": base_url, + } + datasource_provider = DatasourceProvider( + provider="firecrawl", + tenant_id=tenant_id, + plugin_id=firecrawl_plugin_id, + auth_type=api_key_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url="default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_firecrawl_count += 1 + except Exception as e: + click.echo( + click.style( + f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" + ) + ) + continue + db.session.commit() + # deal jina credentials + deal_jina_count = 0 + jina_credentials = 
db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all() + if jina_credentials: + jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} + for jina_credential in jina_credentials: + tenant_id = jina_credential.tenant_id + if tenant_id not in jina_credentials_tenant_mapping: + jina_credentials_tenant_mapping[tenant_id] = [] + jina_credentials_tenant_mapping[tenant_id].append(jina_credential) + for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + if not tenant: + continue + try: + # check jina plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if jina_plugin_id not in installed_plugins_ids: + if jina_plugin_unique_identifier: + # install jina plugin + logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) + PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) + + auth_count = 0 + for jina_tenant_credential in jina_tenant_credentials: + auth_count += 1 + if not jina_tenant_credential.credentials: + click.echo( + click.style( + f"Skipping jina credential for tenant {tenant_id} due to missing credentials.", + fg="yellow", + ) + ) + continue + # get credential api key + credentials_json = json.loads(jina_tenant_credential.credentials) + api_key = credentials_json.get("config", {}).get("api_key") + new_credentials = { + "integration_secret": api_key, + } + datasource_provider = DatasourceProvider( + provider="jinareader", + tenant_id=tenant_id, + plugin_id=jina_plugin_id, + auth_type=api_key_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url="default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_jina_count += 1 + except Exception as e: + click.echo( + click.style(f"Error 
transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red") + ) + continue + db.session.commit() + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green")) + click.echo( + click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green") + ) + click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green")) + + +@click.command("migrate-data-for-plugin", help="Migrate data for plugin.") +def migrate_data_for_plugin(): + """ + Migrate data for plugin. + """ + click.echo(click.style("Starting migrate data for plugin.", fg="white")) + + PluginDataMigration.migrate() + + click.echo(click.style("Migrate data for plugin completed.", fg="green")) + + +@click.command("extract-plugins", help="Extract plugins.") +@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl") +@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10) +def extract_plugins(output_file: str, workers: int): + """ + Extract plugins. + """ + click.echo(click.style("Starting extract plugins.", fg="white")) + + PluginMigration.extract_plugins(output_file, workers) + + click.echo(click.style("Extract plugins completed.", fg="green")) + + +@click.command("extract-unique-identifiers", help="Extract unique identifiers.") +@click.option( + "--output_file", + prompt=True, + help="The file to store the extracted unique identifiers.", + default="unique_identifiers.json", +) +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +def extract_unique_plugins(output_file: str, input_file: str): + """ + Extract unique plugins. 
+ """ + click.echo(click.style("Starting extract unique plugins.", fg="white")) + + PluginMigration.extract_unique_plugins_to_file(input_file, output_file) + + click.echo(click.style("Extract unique plugins completed.", fg="green")) + + +@click.command("install-plugins", help="Install plugins.") +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +@click.option( + "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" +) +@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) +def install_plugins(input_file: str, output_file: str, workers: int): + """ + Install plugins. + """ + click.echo(click.style("Starting install plugins.", fg="white")) + + PluginMigration.install_plugins(input_file, output_file, workers) + + click.echo(click.style("Install plugins completed.", fg="green")) + + +@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.") +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +@click.option( + "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" +) +@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) +def install_rag_pipeline_plugins(input_file, output_file, workers): + """ + Install rag pipeline plugins + """ + click.echo(click.style("Installing rag pipeline plugins", fg="yellow")) + plugin_migration = PluginMigration() + plugin_migration.install_rag_pipeline_plugins( + input_file, + output_file, + workers, + ) + click.echo(click.style("Installing rag pipeline plugins successfully", fg="green")) diff --git a/api/commands/retention.py b/api/commands/retention.py new file mode 100644 index 0000000000..5a91c1cc70 --- /dev/null +++ 
b/api/commands/retention.py @@ -0,0 +1,830 @@ +import datetime +import logging +import time +from typing import Any + +import click +import sqlalchemy as sa + +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup +from tasks.remove_app_and_related_data_task import delete_draft_variables_batch + +logger = logging.getLogger(__name__) + + +@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.") +@click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30) +@click.option("--batch", prompt=True, help="The batch size to clear free plan tenant expired logs.", default=100) +@click.option( + "--tenant_ids", + prompt=True, + multiple=True, + help="The tenant ids to clear free plan tenant expired logs.", +) +def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[str]): + """ + Clear free plan tenant expired logs. 
+ """ + click.echo(click.style("Starting clear free plan tenant expired logs.", fg="white")) + + ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) + + click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) + + +@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") +@click.option( + "--before-days", + "--days", + default=30, + show_default=True, + type=click.IntRange(min=0), + help="Delete workflow runs created before N days ago.", +) +@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option( + "--dry-run", + is_flag=True, + help="Preview cleanup results without deleting any workflow run data.", +) +def clean_workflow_runs( + before_days: int, + batch_size: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + dry_run: bool, +): + """ + Clean workflow runs and related workflow data for free tenants. 
+ """ + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + + if (from_days_ago is None) ^ (to_days_ago is None): + raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.") + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + raise click.UsageError("Choose either day offsets or explicit dates, not both.") + if from_days_ago <= to_days_ago: + raise click.UsageError("--from-days-ago must be greater than --to-days-ago.") + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + start_time = datetime.datetime.now(datetime.UTC) + click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) + + WorkflowRunCleanup( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + dry_run=dry_run, + ).run() + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Workflow run cleanup completed. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "archive-workflow-runs", + help="Archive workflow runs for paid plan tenants to S3-compatible storage.", +) +@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). 
Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created before this timestamp (UTC if no timezone).", +) +@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") +@click.option("--dry-run", is_flag=True, help="Preview without archiving.") +@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +def archive_workflow_runs( + tenant_ids: str | None, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + batch_size: int, + workers: int, + limit: int | None, + dry_run: bool, + delete_after_archive: bool, +): + """ + Archive workflow runs for paid plan tenants older than the specified days. + + This command archives the following tables to storage: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + + The workflow_runs and workflow_app_logs tables are preserved for UI listing. 
+ """ + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + run_started_at = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting workflow run archiving at {run_started_at.isoformat()}.", + fg="white", + ) + ) + + if (start_from is None) ^ (end_before is None): + click.echo(click.style("start-from and end-before must be provided together.", fg="red")) + return + + if (from_days_ago is None) ^ (to_days_ago is None): + click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) + return + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) + return + if from_days_ago <= to_days_ago: + click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) + return + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + return + if workers < 1: + click.echo(click.style("workers must be at least 1.", fg="red")) + return + + archiver = WorkflowRunArchiver( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + workers=workers, + tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + limit=limit, + dry_run=dry_run, + delete_after_archive=delete_after_archive, + ) + summary = archiver.run() + click.echo( + click.style( + f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="cyan", + ) + ) + + run_finished_at = datetime.datetime.now(datetime.UTC) + 
elapsed = run_finished_at - run_started_at + click.echo( + click.style( + f"Workflow run archiving completed. start={run_started_at.isoformat()} " + f"end={run_finished_at.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "restore-workflow-runs", + help="Restore archived workflow runs from S3-compatible storage.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to restore.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--dry-run", is_flag=True, help="Preview without restoring.") +def restore_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + workers: int, + limit: int, + dry_run: bool, +): + """ + Restore an archived workflow run from storage to the database. 
+ + This restores the following tables: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + """ + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch restore.") + if workers < 1: + raise click.BadParameter("workers must be at least 1") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", + fg="white", + ) + ) + + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) + if run_id: + results = [restorer.restore_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = restorer.restore_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Restore completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +@click.command( + "delete-archived-workflow-runs", + help="Delete archived workflow runs from the database.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to delete.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--dry-run", is_flag=True, help="Preview without deleting.") +def delete_archived_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + limit: int, + dry_run: bool, +): + """ + Delete archived workflow runs from the database. 
+ """ + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch delete.") + + start_time = datetime.datetime.now(datetime.UTC) + target_desc = f"workflow run {run_id}" if run_id else "workflow runs" + click.echo( + click.style( + f"Starting delete of {target_desc} at {start_time.isoformat()}.", + fg="white", + ) + ) + + deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) + if run_id: + results = [deleter.delete_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = deleter.delete_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id})", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + fg="red", + ) + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Delete completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Delete completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: + """ + Find draft variables that reference non-existent apps. + + Args: + batch_size: Maximum number of orphaned app IDs to return + + Returns: + List of app IDs that have draft variables but don't exist in the apps table + """ + query = """ + SELECT DISTINCT wdv.app_id + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + LIMIT :batch_size + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(query), {"batch_size": batch_size}) + return [row[0] for row in result] + + +def _count_orphaned_draft_variables() -> dict[str, Any]: + """ + Count orphaned draft variables by app, including associated file counts. + + Returns: + Dictionary with statistics about orphaned variables and files + """ + # Count orphaned variables by app + variables_query = """ + SELECT + wdv.app_id, + COUNT(*) as variable_count, + COUNT(wdv.file_id) as file_count + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + GROUP BY wdv.app_id + ORDER BY variable_count DESC + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(variables_query)) + orphaned_by_app = {} + total_files = 0 + + for row in result: + app_id, variable_count, file_count = row + orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count} + total_files += file_count + + total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values()) + app_count = len(orphaned_by_app) + + return { + "total_orphaned_variables": total_orphaned, + "total_orphaned_files": total_files, + "orphaned_app_count": app_count, + "orphaned_by_app": orphaned_by_app, + } + + +@click.command() +@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting") 
+@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)") +@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)") +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +def cleanup_orphaned_draft_variables( + dry_run: bool, + batch_size: int, + max_apps: int | None, + force: bool = False, +): + """ + Clean up orphaned draft variables from the database. + + This script finds and removes draft variables that belong to apps + that no longer exist in the database. + """ + logger = logging.getLogger(__name__) + + # Get statistics + stats = _count_orphaned_draft_variables() + + logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"]) + logger.info("Found %s associated offload files", stats["total_orphaned_files"]) + logger.info("Across %s non-existent apps", stats["orphaned_app_count"]) + + if stats["total_orphaned_variables"] == 0: + logger.info("No orphaned draft variables found. Exiting.") + return + + if dry_run: + logger.info("DRY RUN: Would delete the following:") + for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[ + :10 + ]: # Show top 10 + logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"]) + if len(stats["orphaned_by_app"]) > 10: + logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10) + return + + # Confirm deletion + if not force: + click.confirm( + f"Are you sure you want to delete {stats['total_orphaned_variables']} " + f"orphaned draft variables and {stats['total_orphaned_files']} associated files " + f"from {stats['orphaned_app_count']} apps?", + abort=True, + ) + + total_deleted = 0 + processed_apps = 0 + + while True: + if max_apps and processed_apps >= max_apps: + logger.info("Reached maximum app limit (%s). 
Stopping.", max_apps) + break + + orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10) + if not orphaned_app_ids: + logger.info("No more orphaned draft variables found.") + break + + for app_id in orphaned_app_ids: + if max_apps and processed_apps >= max_apps: + break + + try: + deleted_count = delete_draft_variables_batch(app_id, batch_size) + total_deleted += deleted_count + processed_apps += 1 + + logger.info("Deleted %s variables for app %s", deleted_count, app_id) + + except Exception: + logger.exception("Error processing app %s", app_id) + continue + + logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) + + +@click.command("clean-expired-messages", help="Clean expired messages.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=False, + default=None, + help="Lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=False, + default=None, + help="Upper bound (exclusive) for created_at.", +) +@click.option( + "--from-days-ago", + type=int, + default=None, + help="Relative lower bound in days ago (inclusive). Must be used with --before-days.", +) +@click.option( + "--before-days", + type=int, + default=None, + help="Relative upper bound in days ago (exclusive). 
Required for relative mode.", +) +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") +@click.option( + "--graceful-period", + default=21, + show_default=True, + help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.", +) +@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting") +def clean_expired_messages( + batch_size: int, + graceful_period: int, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + from_days_ago: int | None, + before_days: int | None, + dry_run: bool, +): + """ + Clean expired messages and related data for tenants based on clean policy. + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + + start_at = time.perf_counter() + + try: + abs_mode = start_from is not None and end_before is not None + rel_mode = before_days is not None + + if abs_mode and rel_mode: + raise click.UsageError( + "Options are mutually exclusive: use either (--start-from,--end-before) " + "or (--from-days-ago,--before-days)." + ) + + if from_days_ago is not None and before_days is None: + raise click.UsageError("--from-days-ago must be used together with --before-days.") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.") + + if not abs_mode and not rel_mode: + raise click.UsageError( + "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])." 
+ ) + + if rel_mode: + assert before_days is not None + if before_days < 0: + raise click.UsageError("--before-days must be >= 0.") + if from_days_ago is not None: + if from_days_ago < 0: + raise click.UsageError("--from-days-ago must be >= 0.") + if from_days_ago <= before_days: + raise click.UsageError("--from-days-ago must be greater than --before-days.") + + # Create policy based on billing configuration + # NOTE: graceful_period will be ignored when billing is disabled. + policy = create_message_clean_policy(graceful_period_days=graceful_period) + + # Create and run the cleanup service + if abs_mode: + assert start_from is not None + assert end_before is not None + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + elif from_days_ago is None: + assert before_days is not None + service = MessagesCleanService.from_days( + policy=policy, + days=before_days, + batch_size=batch_size, + dry_run=dry_run, + ) + else: + assert before_days is not None + assert from_days_ago is not None + now = naive_utc_now() + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=now - datetime.timedelta(days=from_days_ago), + end_before=now - datetime.timedelta(days=before_days), + batch_size=batch_size, + dry_run=dry_run, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + 
) + raise + + click.echo(click.style("messages cleanup completed.", fg="green")) + + +@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.") +@click.option("--app-id", required=True, help="Application ID to export messages for.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Upper bound (exclusive) for created_at.", +) +@click.option( + "--filename", + required=True, + help="Base filename (relative path). Do not include suffix like .jsonl.gz.", +) +@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.") +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.") +@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.") +def export_app_messages( + app_id: str, + start_from: datetime.datetime | None, + end_before: datetime.datetime, + filename: str, + use_cloud_storage: bool, + batch_size: int, + dry_run: bool, +): + if start_from and start_from >= end_before: + raise click.UsageError("--start-from must be before --end-before.") + + from services.retention.conversation.message_export_service import AppMessageExportService + + try: + validated_filename = AppMessageExportService.validate_export_filename(filename) + except ValueError as e: + raise click.BadParameter(str(e), param_hint="--filename") from e + + click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green")) + start_at = time.perf_counter() + + try: + service = AppMessageExportService( + app_id=app_id, + end_before=end_before, + filename=validated_filename, + start_from=start_from, + batch_size=batch_size, + 
use_cloud_storage=use_cloud_storage, + dry_run=dry_run, + ) + stats = service.run() + + elapsed = time.perf_counter() - start_at + click.echo( + click.style( + f"export_app_messages: completed in {elapsed:.2f}s\n" + f" - Batches: {stats.batches}\n" + f" - Total messages: {stats.total_messages}\n" + f" - Messages with feedback: {stats.messages_with_feedback}\n" + f" - Total feedbacks: {stats.total_feedbacks}", + fg="green", + ) + ) + except Exception as e: + elapsed = time.perf_counter() - start_at + logger.exception("export_app_messages failed") + click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red")) + raise diff --git a/api/commands/storage.py b/api/commands/storage.py new file mode 100644 index 0000000000..fa890a855a --- /dev/null +++ b/api/commands/storage.py @@ -0,0 +1,755 @@ +import json + +import click +import sqlalchemy as sa + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_storage import storage +from extensions.storage.opendal_storage import OpenDALStorage +from extensions.storage.storage_type import StorageType +from models.model import UploadFile + + +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +@click.command("clear-orphaned-file-records", help="Clear orphaned file records.") +def clear_orphaned_file_records(force: bool): + """ + Clear orphaned file records in the database. 
+ """ + + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id"}, + {"type": "text", "table": "documents", "column": "data_source_info"}, + {"type": "text", "table": "document_segments", "column": "content"}, + {"type": "text", "table": "messages", "column": "answer"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, + {"type": "text", "table": "conversations", "column": "introduction"}, + {"type": "text", "table": "conversations", "column": "system_instruction"}, + {"type": "text", "table": "accounts", "column": "avatar"}, + {"type": "text", "table": "apps", "column": "icon"}, + {"type": "text", "table": "sites", "column": "icon"}, + {"type": "json", "table": "messages", "column": "inputs"}, + {"type": "json", "table": "messages", "column": "message"}, + ] + + # notify user and ask for confirmation + click.echo( + click.style( + "This command will first find and delete orphaned file records from the message_files table,", fg="yellow" + ) + ) + click.echo( + click.style( + "and then it will find and delete orphaned file records in the following tables:", + fg="yellow", + ) + ) + for files_table in files_tables: + click.echo(click.style(f"- {files_table['table']}", fg="yellow")) + click.echo( + click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow") + ) + for ids_table in ids_tables: + click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")) + click.echo("") + + click.echo(click.style("!!! 
USE WITH CAUTION !!!", fg="red")) + click.echo( + click.style( + ( + "Since not all patterns have been fully tested, " + "please note that this command may delete unintended file records." + ), + fg="yellow", + ) + ) + click.echo( + click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow") + ) + click.echo( + click.style( + ( + "It is also recommended to run this during the maintenance window, " + "as this may cause high load on your instance." + ), + fg="yellow", + ) + ) + if not force: + click.confirm("Do you want to proceed?", abort=True) + + # start the cleanup process + click.echo(click.style("Starting orphaned file records cleanup.", fg="white")) + + # clean up the orphaned records in the message_files table where message_id doesn't exist in messages table + try: + click.echo( + click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white") + ) + query = ( + "SELECT mf.id, mf.message_id " + "FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id " + "WHERE m.id IS NULL" + ) + orphaned_message_files = [] + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) + + if orphaned_message_files: + click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white")) + for record in orphaned_message_files: + click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black")) + + if not force: + click.confirm( + ( + f"Do you want to proceed " + f"to delete all {len(orphaned_message_files)} orphaned message_files records?" 
+ ), + abort=True, + ) + + click.echo(click.style("- Deleting orphaned message_files records", fg="white")) + query = "DELETE FROM message_files WHERE id IN :ids" + with db.engine.begin() as conn: + conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) + click.echo( + click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") + ) + else: + click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red")) + + # clean up the orphaned records in the rest of the *_files tables + try: + # fetch file id and keys from each table + all_files_in_tables = [] + for files_table in files_tables: + click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) + + # fetch referred table and columns + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + all_ids_in_tables = [] + for ids_table in ids_tables: + query = "" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) + ) + c = ids_table["column"] + query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + 
click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case _: + pass + click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) + + except Exception as e: + click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + return + + # find orphaned files + all_files = [file["id"] for file in all_files_in_tables] + all_ids = [file["id"] for file in all_ids_in_tables] + orphaned_files = list(set(all_files) - set(all_ids)) + if not orphaned_files: + click.echo(click.style("No orphaned file records found. 
There is nothing to delete.", fg="green")) + return + click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")) + for file in orphaned_files: + click.echo(click.style(f"- orphaned file id: {file}", fg="black")) + if not force: + click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True) + + # delete orphaned records for each file + try: + for files_table in files_tables: + click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) + query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" + with db.engine.begin() as conn: + conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) + except Exception as e: + click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) + return + click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")) + + +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.") +def remove_orphaned_files_on_storage(force: bool): + """ + Remove orphaned files on the storage. 
+ """ + + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "key_column": "key"}, + {"table": "tool_files", "key_column": "file_key"}, + ] + storage_paths = ["image_files", "tools", "upload_files"] + + # notify user and ask for confirmation + click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow")) + click.echo( + click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow") + ) + for files_table in files_tables: + click.echo(click.style(f"- {files_table['table']}", fg="yellow")) + click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow")) + for storage_path in storage_paths: + click.echo(click.style(f"- {storage_path}", fg="yellow")) + click.echo("") + + click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) + click.echo( + click.style( + "Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow" + ) + ) + click.echo( + click.style( + "Since not all patterns have been fully tested, please note that this command may delete unintended files.", + fg="yellow", + ) + ) + click.echo( + click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow") + ) + click.echo( + click.style( + ( + "It is also recommended to run this during the maintenance window, " + "as this may cause high load on your instance." 
+ ), + fg="yellow", + ) + ) + if not force: + click.confirm("Do you want to proceed?", abort=True) + + # start the cleanup process + click.echo(click.style("Starting orphaned files cleanup.", fg="white")) + + # fetch file id and keys from each table + all_files_in_tables = [] + try: + for files_table in files_tables: + click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) + query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_files_in_tables.append(str(i[0])) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) + except Exception as e: + click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + return + + all_files_on_storage = [] + for storage_path in storage_paths: + try: + click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white")) + files = storage.scan(path=storage_path, files=True, directories=False) + all_files_on_storage.extend(files) + except FileNotFoundError: + click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow")) + continue + except Exception as e: + click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red")) + continue + click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")) + + # find orphaned files + orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) + if not orphaned_files: + click.echo(click.style("No orphaned files found. 
There is nothing to remove.", fg="green")) + return + click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) + for file in orphaned_files: + click.echo(click.style(f"- orphaned file: {file}", fg="black")) + if not force: + click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True) + + # delete orphaned files + removed_files = 0 + error_files = 0 + for file in orphaned_files: + try: + storage.delete(file) + removed_files += 1 + click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) + except Exception as e: + error_files += 1 + click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red")) + continue + if error_files == 0: + click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) + else: + click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) + + +@click.command("file-usage", help="Query file usages and show where files are referenced.") +@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") +@click.option("--key", type=str, default=None, help="Filter by storage key.") +@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") +@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") +@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") +@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") +def file_usage( + file_id: str | None, + key: str | None, + src: str | None, + limit: int, + offset: int, + output_json: bool, +): + """ + Query file usages and show where files are referenced in the database. + + This command reuses the same reference checking logic as clear-orphaned-file-records + and displays detailed information about where each file is referenced. 
+ """ + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, + {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, + {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, + {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, + {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, + {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, + {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, + ] + + # Stream file usages with pagination to avoid holding all results in memory + paginated_usages = [] + total_count = 0 + + # First, build a mapping of file_id -> storage_key from the base tables + file_key_map = {} + for files_table in files_tables: + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" + + # If filtering by key or 
file_id, verify it exists + if file_id and file_id not in file_key_map: + if output_json: + click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) + else: + click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) + return + + if key: + valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} + matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] + if not matching_file_ids: + if output_json: + click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) + else: + click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) + return + + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + # For each reference table/column, find matching file IDs and record the references + for ids_table in ids_tables: + src_filter = f"{ids_table['table']}.{ids_table['column']}" + + # Skip if src filter doesn't match (use fnmatch for wildcard patterns) + if src: + if "%" in src or "_" in src: + import fnmatch + + # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) 
+ pattern = src.replace("%", "*").replace("_", "?") + if not fnmatch.fnmatch(src_filter, pattern): + continue + else: + if src_filter != src: + continue + + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + 
"src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + + # Output results + if output_json: + result = { + "total": total_count, + "offset": offset, + "limit": limit, + "usages": paginated_usages, + } + click.echo(json.dumps(result, indent=2)) + else: + click.echo( + click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") + ) + click.echo("") + + if not paginated_usages: + click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) + return + + # Print table header + click.echo( + click.style( + f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", + fg="cyan", + ) + ) + click.echo(click.style("-" * 190, fg="white")) + + # Print each usage + for usage in paginated_usages: + click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") + + # Show pagination info + if offset + limit < total_count: + click.echo("") + click.echo( + click.style( + f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" + ) + ) + click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) + + +@click.command( + "migrate-oss", + help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).", +) +@click.option( + "--path", + "paths", + multiple=True, + help="Storage path prefixes to migrate (repeatable). 
Defaults: privkeys, upload_files, image_files," + " tools, website_files, keyword_files, ops_trace", +) +@click.option( + "--source", + type=click.Choice(["local", "opendal"], case_sensitive=False), + default="opendal", + show_default=True, + help="Source storage type to read from", +) +@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists") +@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading") +@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts") +@click.option( + "--update-db/--no-update-db", + default=True, + help="Update upload_files.storage_type from source type to current storage after migration", +) +def migrate_oss( + paths: tuple[str, ...], + source: str, + overwrite: bool, + dry_run: bool, + force: bool, + update_db: bool, +): + """ + Copy all files under selected prefixes from a source storage + (Local filesystem or OpenDAL-backed) into the currently configured + destination storage backend, then optionally update DB records. + + Expected usage: set STORAGE_TYPE (and its credentials) to your target backend. 
+ """ + # Ensure target storage is not local/opendal + if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL): + click.echo( + click.style( + "Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n" + "Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n" + "volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.", + fg="red", + ) + ) + return + + # Default paths if none specified + default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace") + path_list = list(paths) if paths else list(default_paths) + is_source_local = source.lower() == "local" + + click.echo(click.style("Preparing migration to target storage.", fg="yellow")) + click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white")) + if is_source_local: + src_root = dify_config.STORAGE_LOCAL_PATH + click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white")) + else: + click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white")) + click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white")) + click.echo("") + + if not force: + click.confirm("Proceed with migration?", abort=True) + + # Instantiate source storage + try: + if is_source_local: + src_root = dify_config.STORAGE_LOCAL_PATH + source_storage = OpenDALStorage(scheme="fs", root=src_root) + else: + source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME) + except Exception as e: + click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red")) + return + + total_files = 0 + copied_files = 0 + skipped_files = 0 + errored_files = 0 + copied_upload_file_keys: list[str] = [] + + for prefix in path_list: + click.echo(click.style(f"Scanning source path: {prefix}", fg="white")) + try: + keys = source_storage.scan(path=prefix, files=True, directories=False) + except FileNotFoundError: + 
click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow")) + continue + except NotImplementedError: + click.echo(click.style(" -> Source storage does not support scanning.", fg="red")) + return + except Exception as e: + click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red")) + continue + + click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white")) + + for key in keys: + total_files += 1 + + # check destination existence + if not overwrite: + try: + if storage.exists(key): + skipped_files += 1 + continue + except Exception as e: + # existence check failures should not block migration attempt + # but should be surfaced to user as a warning for visibility + click.echo( + click.style( + f" -> Warning: failed target existence check for {key}: {str(e)}", + fg="yellow", + ) + ) + + if dry_run: + copied_files += 1 + continue + + # read from source and write to destination + try: + data = source_storage.load_once(key) + except FileNotFoundError: + errored_files += 1 + click.echo(click.style(f" -> Missing on source: {key}", fg="yellow")) + continue + except Exception as e: + errored_files += 1 + click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red")) + continue + + try: + storage.save(key, data) + copied_files += 1 + if prefix == "upload_files": + copied_upload_file_keys.append(key) + except Exception as e: + errored_files += 1 + click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red")) + continue + + click.echo("") + click.echo(click.style("Migration summary:", fg="yellow")) + click.echo(click.style(f" Total: {total_files}", fg="white")) + click.echo(click.style(f" Copied: {copied_files}", fg="green")) + click.echo(click.style(f" Skipped: {skipped_files}", fg="white")) + if errored_files: + click.echo(click.style(f" Errors: {errored_files}", fg="red")) + + if dry_run: + click.echo(click.style("Dry-run complete. 
@click.command(
    "reset-encrypt-key-pair",
    help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
    "After the reset, all LLM credentials will become invalid, "
    "requiring re-entry. "
    "Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
    prompt=click.style(
        "Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
    )
)
def reset_encrypt_key_pair():
    """
    Reset the asymmetric key pair of every workspace (tenant).

    For each tenant the stored public key is replaced with the value returned by
    ``generate_key_pair(tenant.id)``, and the tenant's custom ``Provider`` rows and all of
    its ``ProviderModel`` rows are deleted, so credentials must be entered again.

    The command does nothing unless ``dify_config.EDITION`` is ``SELF_HOSTED``. All changes
    run inside one ``sessionmaker(...).begin()`` block.
    """
    if dify_config.EDITION != "SELF_HOSTED":
        click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
        return

    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
        tenants = session.query(Tenant).all()

        # The emptiness check must happen before the loop: iterating an empty list never
        # enters the loop body, so a check placed inside it can never report this case.
        if not tenants:
            click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
            return

        for tenant in tenants:
            tenant.encrypt_public_key = generate_key_pair(tenant.id)

            # Credentials encrypted with the previous key can no longer be decrypted; drop them.
            session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
            session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()

            click.echo(
                click.style(
                    f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
                    fg="green",
                )
            )
+ """ + click.echo(click.style("Starting convert to agent apps.", fg="green")) + + proceeded_app_ids = [] + + while True: + # fetch first 1000 apps + sql_query = """SELECT a.id AS id FROM apps a + INNER JOIN app_model_configs am ON a.app_model_config_id=am.id + WHERE a.mode = 'chat' + AND am.agent_mode is not null + AND ( + am.agent_mode like '%"strategy": "function_call"%' + OR am.agent_mode like '%"strategy": "react"%' + ) + AND ( + am.agent_mode like '{"enabled": true%' + OR am.agent_mode like '{"max_iteration": %' + ) ORDER BY a.created_at DESC LIMIT 1000 + """ + + with db.engine.begin() as conn: + rs = conn.execute(sa.text(sql_query)) + + apps = [] + for i in rs: + app_id = str(i.id) + if app_id not in proceeded_app_ids: + proceeded_app_ids.append(app_id) + app = db.session.query(App).where(App.id == app_id).first() + if app is not None: + apps.append(app) + + if len(apps) == 0: + break + + for app in apps: + click.echo(f"Converting app: {app.id}") + + try: + app.mode = AppMode.AGENT_CHAT + db.session.commit() + + # update conversation mode to agent + db.session.query(Conversation).where(Conversation.app_id == app.id).update( + {Conversation.mode: AppMode.AGENT_CHAT} + ) + + db.session.commit() + click.echo(click.style(f"Converted app: {app.id}", fg="green")) + except Exception as e: + click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red")) + + click.echo(click.style(f"Conversion complete. 
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing():
    """
    Fix app related site missing issue.

    Repeatedly selects up to 1000 apps that have no row in ``sites`` and, for each one,
    sends the ``app_was_created`` signal with the first account of the app's tenant.

    Every app is attempted at most once per run. The loop stops as soon as a batch
    contains no app that has not been attempted yet, so apps that cannot be fixed
    (missing tenant, tenant without accounts, or a failing signal handler) can no longer
    keep the command spinning on the same rows.
    """
    click.echo(click.style("Starting fix for missing app-related sites.", fg="green"))

    sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
where sites.id is null limit 1000"""

    failed_app_ids: list[str] = []
    # Every app id already handled in this run, whether it was fixed, skipped or failed.
    attempted_app_ids: set[str] = set()

    while True:
        with db.engine.begin() as conn:
            app_ids = [str(row.id) for row in conn.execute(sa.text(sql))]

        # Only apps not seen before count as progress; unfixable apps are returned by the
        # query on every pass and must not keep the loop alive.
        new_app_ids = [app_id for app_id in app_ids if app_id not in attempted_app_ids]
        if not new_app_ids:
            break

        for app_id in new_app_ids:
            attempted_app_ids.add(app_id)

            try:
                app = db.session.query(App).where(App.id == app_id).first()
                if not app:
                    logger.info("App %s not found", app_id)
                    continue

                tenant = app.tenant
                if not tenant:
                    continue

                accounts = tenant.get_accounts()
                if not accounts:
                    logger.info("Fix failed for app %s", app.id)
                    continue

                account = accounts[0]
                logger.info("Fixing missing site for app %s", app.id)
                app_was_created.send(app, account=account)
            except Exception:
                # Discard any half-finished work so the shared session is usable for the next app.
                db.session.rollback()
                failed_app_ids.append(app_id)
                click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
                logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
                continue

    if failed_app_ids:
        click.echo(
            click.style(
                f"Fix for missing app-related sites completed with {len(failed_app_ids)} failed apps.",
                fg="yellow",
            )
        )
    else:
        click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
def migrate_annotation_vector_database():
    """
    Migrate annotation data to the target vector database.

    Walks all apps with status ``normal`` in pages of 50 (newest first). For each app that
    has an annotation setting and a collection binding, the app's vector index is deleted
    and, when the app has annotations, recreated from their question texts.

    Apps without an annotation setting or without a collection binding are counted as
    skipped. An error while migrating one app is reported and the run continues with the
    next app.
    """
    click.echo(click.style("Starting annotation data migration.", fg="green"))
    create_count = 0
    skipped_count = 0
    total_count = 0
    per_page = 50
    page = 1
    while True:
        # get apps info
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            apps = (
                session.query(App)
                .where(App.status == "normal")
                .order_by(App.created_at.desc())
                .limit(per_page)
                .offset((page - 1) * per_page)
                .all()
            )
        if not apps:
            break

        page += 1
        for app in apps:
            total_count += 1
            click.echo(
                f"Processing the {total_count} app {app.id}. "
                f"{create_count} created, {skipped_count} skipped."
            )
            try:
                click.echo(f"Creating app annotation index: {app.id}")
                with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
                    app_annotation_setting = (
                        session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
                    )
                    if not app_annotation_setting:
                        skipped_count += 1
                        click.echo(f"App annotation setting disabled: {app.id}")
                        continue

                    # get dataset_collection_binding info
                    dataset_collection_binding = (
                        session.query(DatasetCollectionBinding)
                        .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
                        .first()
                    )
                    if not dataset_collection_binding:
                        # Count this app as skipped so created + skipped + errors adds up to the total.
                        skipped_count += 1
                        click.echo(f"App annotation collection binding not found: {app.id}")
                        continue

                    annotations = session.scalars(
                        select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
                    ).all()

                # Dataset object built in memory only to describe the collection to Vector;
                # it is not added to any session here.
                dataset = Dataset(
                    id=app.id,
                    tenant_id=app.tenant_id,
                    indexing_technique="high_quality",
                    embedding_model_provider=dataset_collection_binding.provider_name,
                    embedding_model=dataset_collection_binding.model_name,
                    collection_binding_id=dataset_collection_binding.id,
                )
                documents = [
                    Document(
                        page_content=annotation.question_text,
                        metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
                    )
                    for annotation in annotations
                ]

                vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
                click.echo(f"Migrating annotations for app: {app.id}.")

                try:
                    vector.delete()
                    click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
                except Exception:
                    click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
                    raise

                if documents:
                    try:
                        click.echo(
                            click.style(
                                f"Creating vector index with {len(documents)} annotations for app {app.id}.",
                                fg="green",
                            )
                        )
                        vector.create(documents)
                        click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
                    except Exception:
                        click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
                        raise

                click.echo(f"Successfully migrated app annotation {app.id}.")
                create_count += 1
            except Exception as e:
                # Best effort: report the failure and continue with the remaining apps.
                click.echo(
                    click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red")
                )
                continue

    click.echo(
        click.style(
            f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.",
            fg="green",
        )
    )
+ ) + try: + click.echo(f"Creating dataset vector database index: {dataset.id}") + if dataset.index_struct_dict: + if dataset.index_struct_dict["type"] == vector_type: + skipped_count = skipped_count + 1 + continue + collection_name = "" + dataset_id = dataset.id + if vector_type in upper_collection_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + elif vector_type == VectorType.QDRANT: + if dataset.collection_binding_id: + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .where(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) + if dataset_collection_binding: + collection_name = dataset_collection_binding.collection_name + else: + raise ValueError("Dataset Collection Binding not found") + else: + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + + elif vector_type in lower_collection_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + else: + raise ValueError(f"Vector store {vector_type} is not supported.") + + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + dataset.index_struct = json.dumps(index_struct_dict) + vector = Vector(dataset) + click.echo(f"Migrating dataset {dataset.id}.") + + try: + vector.delete() + click.echo( + click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green") + ) + except Exception as e: + click.echo( + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" + ) + ) + raise e + + dataset_documents = db.session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() + + documents = [] + segments_count = 0 + for dataset_document in dataset_documents: + segments = db.session.scalars( + 
select(DocumentSegment).where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + ).all() + + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == "hierarchical_model": + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + + documents.append(document) + segments_count = segments_count + 1 + + if documents: + try: + click.echo( + click.style( + f"Creating vector index with {len(documents)} documents of {segments_count}" + f" segments for dataset {dataset.id}.", + fg="green", + ) + ) + all_child_documents = [] + for doc in documents: + if doc.children: + all_child_documents.extend(doc.children) + vector.create(documents) + if all_child_documents: + vector.create(all_child_documents) + click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green")) + except Exception as e: + click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) + raise e + db.session.add(dataset) + db.session.commit() + click.echo(f"Successfully migrated dataset {dataset.id}.") + create_count += 1 + except Exception as e: + db.session.rollback() + click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red")) + continue + + click.echo( + click.style( + f"Migration complete. 
Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green" + ) + ) + + +@click.command("add-qdrant-index", help="Add Qdrant index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.") +def add_qdrant_index(field: str): + click.echo(click.style("Starting Qdrant index creation.", fg="green")) + + create_count = 0 + + try: + bindings = db.session.query(DatasetCollectionBinding).all() + if not bindings: + click.echo(click.style("No dataset collection bindings found.", fg="red")) + return + import qdrant_client + from qdrant_client.http.exceptions import UnexpectedResponse + from qdrant_client.http.models import PayloadSchemaType + + from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig + + for binding in bindings: + if dify_config.QDRANT_URL is None: + raise ValueError("Qdrant URL is required.") + qdrant_config = QdrantConfig( + endpoint=dify_config.QDRANT_URL, + api_key=dify_config.QDRANT_API_KEY, + root_path=current_app.root_path, + timeout=dify_config.QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.QDRANT_GRPC_PORT, + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ) + try: + params = qdrant_config.to_qdrant_params() + # Check the type before using + if isinstance(params, PathQdrantParams): + # PathQdrantParams case + client = qdrant_client.QdrantClient(path=params.path) + else: + # UrlQdrantParams case - params is UrlQdrantParams + client = qdrant_client.QdrantClient( + url=params.url, + api_key=params.api_key, + timeout=int(params.timeout), + verify=params.verify, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + ) + # create payload index + client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) + create_count += 1 + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + click.echo(click.style(f"Collection not found: 
@click.command("old-metadata-migration", help="Old metadata migration.")
def old_metadata_migration():
    """
    Old metadata migration.

    Pages through all dataset documents whose ``doc_metadata`` is not NULL (50 per page,
    newest first). For every metadata key that is not a ``BuiltInField`` value it makes
    sure that

    * a ``DatasetMetadata`` row of type ``string`` exists for (dataset, key), and
    * a ``DatasetMetadataBinding`` row links that metadata to the document.

    Existing rows are reused, so running the command again does not add duplicates.
    """
    click.echo(click.style("Starting old metadata migration.", fg="green"))

    built_in_names = {field.value for field in BuiltInField}
    stmt = (
        select(DatasetDocument)
        .where(DatasetDocument.doc_metadata.is_not(None))
        .order_by(DatasetDocument.created_at.desc())
    )

    page = 1
    while True:
        documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
        # Test the page's items, as migrate_knowledge_vector_database does. With
        # error_out=False a page past the end still yields a pagination object, so the
        # object itself is not a usable "no more rows" signal.
        if not documents.items:
            break

        for document in documents.items:
            if not document.doc_metadata:
                continue

            for key in document.doc_metadata:
                # Built-in fields are not stored as user-defined metadata.
                if key in built_in_names:
                    continue

                dataset_metadata = (
                    db.session.query(DatasetMetadata)
                    .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
                    .first()
                )
                if dataset_metadata:
                    dataset_metadata_binding = (
                        db.session.query(DatasetMetadataBinding)  # type: ignore
                        .where(
                            DatasetMetadataBinding.dataset_id == document.dataset_id,
                            DatasetMetadataBinding.document_id == document.id,
                            DatasetMetadataBinding.metadata_id == dataset_metadata.id,
                        )
                        .first()
                    )
                else:
                    dataset_metadata = DatasetMetadata(
                        tenant_id=document.tenant_id,
                        dataset_id=document.dataset_id,
                        name=key,
                        type="string",
                        created_by=document.created_by,
                    )
                    db.session.add(dataset_metadata)
                    # Flush so dataset_metadata.id is populated before the binding refers to it.
                    db.session.flush()
                    dataset_metadata_binding = None

                if not dataset_metadata_binding:
                    db.session.add(
                        DatasetMetadataBinding(
                            tenant_id=document.tenant_id,
                            dataset_id=document.dataset_id,
                            metadata_id=dataset_metadata.id,
                            document_id=document.id,
                            created_by=document.created_by,
                        )
                    )
                db.session.commit()
        page += 1

    click.echo(click.style("Old metadata migration completed.", fg="green"))
import AppUnavailableError @@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index f6032a8e49..9e3fb3a90b 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager -from extensions.ext_database import db as global_db DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -57,7 +56,7 @@ class ToolFileApi(Resource): raise Forbidden("Invalid request.") try: - tool_file_manager = ToolFileManager(engine=global_db.engine) + tool_file_manager = ToolFileManager() stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( file_id, ) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 562f5e33cc..abcaa0e240 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from flask_restx import Resource from controllers.common.fields import Parameters @@ -33,14 +35,14 @@ class AppParameterApi(Resource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = 
workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index dc8da025d4..5a1d28ea1d 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,8 +1,9 @@ import json +from contextlib import ExitStack from typing import Self from uuid import UUID -from flask import request +from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import desc, select @@ -100,6 +101,15 @@ class DocumentListQuery(BaseModel): status: str | None = Field(default=None, description="Document status filter") +DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 + + +class DocumentBatchDownloadZipPayload(BaseModel): + """Request payload for bulk downloading uploaded documents as a ZIP archive.""" + + document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) + + register_enum_models(service_api_ns, RetrievalMethod) register_schema_models( @@ -109,6 +119,7 @@ register_schema_models( DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery, + DocumentBatchDownloadZipPayload, Rule, PreProcessingRule, Segmentation, @@ -540,6 +551,46 @@ class DocumentListApi(DatasetApiResource): return response +@service_api_ns.route("/datasets//documents/download-zip") +class DocumentBatchDownloadZipApi(DatasetApiResource): + """Download multiple uploaded-file documents as a single ZIP archive.""" + + @service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__]) + @service_api_ns.doc("download_documents_as_zip") + 
@service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "ZIP archive generated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Document or dataset not found", + } + ) + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def post(self, tenant_id, dataset_id): + payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {}) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=str(dataset_id), + document_ids=[str(document_id) for document_id in payload.document_ids], + tenant_id=str(tenant_id), + current_user=current_user, + ) + + with ExitStack() as stack: + zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files)) + response = send_file( + zip_path, + mimetype="application/zip", + as_attachment=True, + download_name=download_name, + ) + cleanup = stack.pop_all() + response.call_on_close(cleanup.close) + return response + + @service_api_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DatasetApiResource): @service_api_ns.doc("get_document_indexing_status") @@ -600,6 +651,35 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data +@service_api_ns.route("/datasets//documents//download") +class DocumentDownloadApi(DatasetApiResource): + """Return a signed download URL for a document's original uploaded file.""" + + @service_api_ns.doc("get_document_download_url") + @service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Download URL generated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden 
- insufficient permissions", + 404: "Document or upload file not found", + } + ) + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def get(self, tenant_id, dataset_id, document_id): + dataset = self.get_dataset(str(dataset_id), str(tenant_id)) + document = DocumentService.get_document(dataset.id, str(document_id)) + + if not document: + raise NotFound("Document not found.") + + if document.tenant_id != str(tenant_id): + raise Forbidden("No permission.") + + return {"url": DocumentService.get_document_download_url(document)} + + @service_api_ns.route("/datasets//documents/") class DocumentApi(DatasetApiResource): METADATA_CHOICES = {"all", "only", "without"} diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 62ea532eac..25bbedce54 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,5 @@ import logging +from typing import Any, cast from flask import request from flask_restx import Resource @@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index bbae1ce266..2b60691949 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: - raise NotCompletionAppError() + raise 
NotChatAppError() message_id = str(message_id) diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index e925d6dd52..7d1b11c008 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping +from typing import Any + from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: + def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None @@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager: if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( type=sensitive_word_avoidance_dict.get("type"), - config=sensitive_word_avoidance_dict.get("config"), + config=sensitive_word_avoidance_dict.get("config", {}), ) else: return None diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 9b981dfc09..10db380d1f 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,10 +1,13 @@ +from typing import Any, cast + from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES +from models.model import AppModelConfigDict class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> AgentEntity | None: + def convert(cls, config: AppModelConfigDict) -> AgentEntity | None: """ Convert model config to model config @@ -28,17 
+31,17 @@ class AgentConfigManager: agent_tools = [] for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: + tool_dict = cast(dict[str, Any], tool) + if len(tool_dict) >= 4: + if "enabled" not in tool_dict or not tool_dict["enabled"]: continue agent_tool_properties = { - "provider_type": tool["provider_type"], - "provider_id": tool["provider_id"], - "tool_name": tool["tool_name"], - "tool_parameters": tool.get("tool_parameters", {}), - "credential_id": tool.get("credential_id", None), + "provider_type": tool_dict["provider_type"], + "provider_id": tool_dict["provider_id"], + "tool_name": tool_dict["tool_name"], + "tool_parameters": tool_dict.get("tool_parameters", {}), + "credential_id": tool_dict.get("credential_id", None), } agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) @@ -47,7 +50,8 @@ class AgentConfigManager: "react_router", "router", }: - agent_prompt = agent_dict.get("prompt", None) or {} + agent_prompt_raw = agent_dict.get("prompt", None) + agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {} # check model mode model_mode = config.get("model", {}).get("mode", "completion") if model_mode == "completion": @@ -75,7 +79,7 @@ class AgentConfigManager: strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get("max_iteration", 10), + max_iteration=cast(int, agent_dict.get("max_iteration", 10)), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index aacafb2dad..70f43b2c83 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,5 @@ import uuid -from typing import Literal, cast +from typing import Any, Literal, cast from core.app.app_config.entities import ( DatasetEntity, @@ -8,13 
+8,13 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> DatasetEntity | None: + def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None: """ Convert model config to model config @@ -25,11 +25,15 @@ class DatasetConfigManager: datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) for dataset in datasets.get("datasets", []): + if not isinstance(dataset, dict): + continue keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != "dataset": continue dataset = dataset["dataset"] + if not isinstance(dataset, dict): + continue if "enabled" not in dataset or not dataset["enabled"]: continue @@ -47,15 +51,14 @@ class DatasetConfigManager: agent_dict = config.get("agent_mode", {}) for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) == 1: + if len(tool) == 1: # old standard key = list(tool.keys())[0] if key != "dataset": continue - tool_item = tool[key] + tool_item = cast(dict[str, Any], tool)[key] if "enabled" not in tool_item or not tool_item["enabled"]: continue diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index e4e750c735..0929f52e33 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from 
models.model import AppModelConfigDict from models.provider_ids import ModelProviderID class ModelConfigManager: @classmethod - def convert(cls, config: dict) -> ModelConfigEntity: + def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity: """ Convert model config to model config @@ -22,7 +23,7 @@ class ModelConfigManager: if not model_config: raise ValueError("model is required") - completion_params = model_config.get("completion_params") + completion_params = model_config.get("completion_params") or {} stop = [] if "stop" in completion_params: stop = completion_params["stop"] diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 01b9601965..b7073898d6 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -6,12 +8,12 @@ from core.app.app_config.entities import ( ) from core.prompt.simple_prompt_transform import ModelMode from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class PromptTemplateConfigManager: @classmethod - def convert(cls, config: dict) -> PromptTemplateEntity: + def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity: if not config.get("prompt_type"): raise ValueError("prompt_type is required") @@ -40,14 +42,15 @@ class PromptTemplateConfigManager: advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: - completion_prompt_template_params = { + completion_prompt_template_params: dict[str, Any] = { "prompt": completion_prompt_config["prompt"]["text"], } - if 
"conversation_histories_role" in completion_prompt_config: + conv_role = completion_prompt_config.get("conversation_histories_role") + if conv_role: completion_prompt_template_params["role_prefix"] = { - "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], - "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], + "user": conv_role["user_prefix"], + "assistant": conv_role["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 157e5d8bc0..8de1224a89 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,8 +1,10 @@ import re +from typing import cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( [ @@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( class BasicVariablesConfigManager: @classmethod - def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: """ Convert model config to model config @@ -51,7 +53,9 @@ class BasicVariablesConfigManager: external_data_variables.append( ExternalDataVariableEntity( - variable=variable["variable"], type=variable["type"], config=variable["config"] + variable=variable["variable"], + type=variable.get("type", ""), + config=variable.get("config", {}), ) ) elif variable_type in { @@ -64,10 +68,10 @@ class BasicVariablesConfigManager: variable = variables[variable_type] 
variable_entities.append( VariableEntity( - type=variable_type, - variable=variable.get("variable"), + type=cast(VariableEntityType, variable_type), + variable=variable["variable"], description=variable.get("description") or "", - label=variable.get("label"), + label=variable["label"], required=variable.get("required", False), max_length=variable.get("max_length"), options=variable.get("options") or [], diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index f26351d93e..ac21577d57 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str - app_model_config_dict: dict + app_model_config_dict: dict[str, Any] model: ModelConfigEntity prompt_template: PromptTemplateEntity dataset: DatasetEntity | None = None diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index b38dfdfc1f..66037696af 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -138,20 +138,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query = self.application_generate_entity.query # moderation - if self.handle_input_moderation( + stop, new_inputs, new_query = self.handle_input_moderation( app_record=self._app, app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, message_id=self.message.id, - ): + ) + if stop: return + self.application_generate_entity.inputs = new_inputs + self.application_generate_entity.query = new_query + system_inputs.query = new_query + # annotation reply if self.handle_annotation_reply( app_record=self._app, message=self.message, - query=query, + query=new_query, app_generate_entity=self.application_generate_entity, ): return @@ -163,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # init variable pool 
variable_pool = VariablePool( system_variables=system_inputs, - user_inputs=inputs, + user_inputs=new_inputs, environment_variables=self._workflow.environment_variables, # Based on the definition of `Variable`, # `VariableBase` instances can be safely used as `Variable` since they are compatible. @@ -240,10 +245,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): inputs: Mapping[str, Any], query: str, message_id: str, - ) -> bool: + ) -> tuple[bool, Mapping[str, Any], str]: try: # process sensitive_word_avoidance - _, inputs, query = self.moderation_for_inputs( + _, new_inputs, new_query = self.moderation_for_inputs( app_id=app_record.id, tenant_id=app_generate_entity.app_config.tenant_id, app_generate_entity=app_generate_entity, @@ -253,9 +258,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) except ModerationError as e: self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) - return True + return True, inputs, query - return False + return False, new_inputs, new_query def handle_annotation_reply( self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index fbd5060b8c..a1cb375e24 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): graph_runtime_state=validated_state, ) + yield from self._handle_advanced_chat_message_end_event( + QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state + ) yield workflow_finish_resp - self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_partial_success_event( self, @@ -538,6 +540,9 @@ class 
AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): exceptions_count=event.exceptions_count, ) + yield from self._handle_advanced_chat_message_end_event( + QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state + ) yield workflow_finish_resp def _handle_workflow_paused_event( @@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): yield from self._handle_workflow_paused_event(event) break + case QueueWorkflowSucceededEvent(): + yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager) + break + + case QueueWorkflowPartialSuccessEvent(): + yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager) + break + case QueueStopEvent(): yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) break diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 801619ddbc..f0d81e0c59 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config @@ -61,7 +61,9 @@ class 
AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( @@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict: """ Validate for agent chat app model config @@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) @classmethod def validate_agent_mode_and_set_defaults( diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 4b6720a3c3..5f087f6066 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from 
core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor SuggestedQuestionsAfterAnswerConfigManager, ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation class ChatAppConfig(EasyUIBasedAppConfig): @@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config @@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for chat app model config @@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 23546a47bb..f63b38fc86 100644 --- 
a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index eb1902f12e..f49e7b8b5e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict class CompletionAppConfig(EasyUIBasedAppConfig): @@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = 
None ) -> CompletionAppConfig: """ Convert app model config to completion app config @@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( @@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for completion app model config @@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index e8b0e4f179..002b914ef1 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = 
override_model_config_dict["model"] - completion_params = model_dict.get("completion_params") + completion_params = model_dict.get("completion_params", {}) completion_params["temperature"] = 0.9 model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index ac05172945..56a4519879 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 1fa782eb6c..b530fe1ce4 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Union, cast +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -44,14 +44,13 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.app.task_pipeline.message_file_utils import prepare_file_dict from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_manager import ModelInstance from 
core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.tools.signature import sign_tool_file -from dify_graph.file import helpers as file_helpers from dify_graph.file.enums import FileTransferMethod from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from dify_graph.model_runtime.entities.message_entities import ( @@ -219,14 +218,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech")) if ( text_to_speech_dict and text_to_speech_dict.get("autoPlay") == "enabled" and text_to_speech_dict.get("enabled") ): publisher = AppGeneratorTTSPublisher( - tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None) + tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None) ) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: @@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): """ self._task_state.metadata.usage = self._task_state.llm_result.usage metadata_dict = self._task_state.metadata.model_dump() + + # Fetch files associated with this message + files = None + with Session(db.engine, expire_on_commit=False) as session: + message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all() + + if message_files: + # Fetch all required UploadFile objects in a single 
query to avoid N+1 problem + upload_file_ids = list( + dict.fromkeys( + mf.upload_file_id + for mf in message_files + if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id + ) + ) + upload_files_map = {} + if upload_file_ids: + upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all() + upload_files_map = {uf.id: uf for uf in upload_files} + + files_list = [] + for message_file in message_files: + file_dict = prepare_file_dict(message_file, upload_files_map) + files_list.append(file_dict) + + files = files_list or None + return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, metadata=metadata_dict, + files=files, ) - def _record_files(self): - with Session(db.engine, expire_on_commit=False) as session: - message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all() - if not message_files: - return None - - files_list = [] - upload_file_ids = [ - mf.upload_file_id - for mf in message_files - if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id - ] - upload_files_map = {} - if upload_file_ids: - upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all() - upload_files_map = {uf.id: uf for uf in upload_files} - - for message_file in message_files: - upload_file = None - if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id: - upload_file = upload_files_map.get(message_file.upload_file_id) - - url = None - filename = "file" - mime_type = "application/octet-stream" - size = 0 - extension = "" - - if message_file.transfer_method == FileTransferMethod.REMOTE_URL: - url = message_file.url - if message_file.url: - filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params - elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if upload_file: - url = 
file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) - filename = upload_file.name - mime_type = upload_file.mime_type or "application/octet-stream" - size = upload_file.size or 0 - extension = f".{upload_file.extension}" if upload_file.extension else "" - elif message_file.upload_file_id: - # Fallback: generate URL even if upload_file not found - url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: - # For tool files, use URL directly if it's HTTP, otherwise sign it - if message_file.url.startswith("http"): - url = message_file.url - filename = message_file.url.split("/")[-1].split("?")[0] - else: - # Extract tool file id and extension from URL - url_parts = message_file.url.split("/") - if url_parts: - file_part = url_parts[-1].split("?")[0] # Remove query params first - # Use rsplit to correctly handle filenames with multiple dots - if "." in file_part: - tool_file_id, ext = file_part.rsplit(".", 1) - extension = f".{ext}" - else: - tool_file_id = file_part - extension = ".bin" - url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) - filename = file_part - - transfer_method_value = message_file.transfer_method - remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" - file_dict = { - "related_id": message_file.id, - "extension": extension, - "filename": filename, - "size": size, - "mime_type": mime_type, - "transfer_method": transfer_method_value, - "type": message_file.type, - "url": url or "", - "upload_file_id": message_file.upload_file_id or message_file.id, - "remote_url": remote_url, - } - files_list.append(file_dict) - return files_list or None - def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: """ Agent message to stream response. 
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index cc4f97ad94..536ab02eae 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -1,7 +1,6 @@ import hashlib import logging -import time -from threading import Thread +from threading import Thread, Timer from typing import Union from flask import Flask, current_app @@ -96,9 +95,9 @@ class MessageCycleManager: if auto_generate_conversation_name and is_first_message: # start generate thread # time.sleep not block other logic - time.sleep(1) - thread = Thread( - target=self._generate_conversation_name_worker, + thread = Timer( + 1, + self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore "conversation_id": conversation_id, diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py new file mode 100644 index 0000000000..843e9eea30 --- /dev/null +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -0,0 +1,76 @@ +from core.tools.signature import sign_tool_file +from dify_graph.file import helpers as file_helpers +from dify_graph.file.enums import FileTransferMethod +from models.model import MessageFile, UploadFile + +MAX_TOOL_FILE_EXTENSION_LENGTH = 10 + + +def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict: + """ + Prepare file dictionary for message end stream response. 
+ + :param message_file: MessageFile instance + :param upload_files_map: Dictionary mapping upload_file_id to UploadFile + :return: Dictionary containing file information + """ + upload_file = None + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id: + upload_file = upload_files_map.get(message_file.upload_file_id) + + url = None + filename = "file" + mime_type = "application/octet-stream" + size = 0 + extension = "" + + if message_file.transfer_method == FileTransferMethod.REMOTE_URL: + url = message_file.url + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name + mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: + if message_file.url.startswith(("http://", "https://")): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + else: + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] + if "." 
in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + extension = ".bin" + else: + tool_file_id = file_part + extension = ".bin" + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + + transfer_method_value = message_file.transfer_method.value + remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" + return { + "related_id": message_file.id, + "extension": extension, + "filename": filename, + "size": size, + "mime_type": mime_type, + "transfer_method": transfer_method_value, + "type": message_file.type, + "url": url or "", + "upload_file_id": message_file.upload_file_id or message_file.id, + "remote_url": remote_url, + } diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index bae39dc8c7..4b47777f0b 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -59,8 +59,6 @@ class DatasourcePluginProviderController(ABC): :param credentials: the credentials of the tool """ credentials_schema = dict[str, ProviderConfig]() - if credentials_schema is None: - return for credential in self.entity.credentials_schema: credentials_schema[credential.name] = credential diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 7624586367..0e00e90520 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,7 +7,7 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime -from typing import Final, cast +from typing import Final from urllib.parse import urljoin import httpx @@ -201,7 +201,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int: raise ValueError("UUID cannot be 
None") try: uuid_obj = uuid.UUID(uuid_v4) - return cast(int, uuid_obj.int) + return uuid_obj.int except ValueError as e: raise ValueError(f"Invalid UUID input: {uuid_v4}") from e diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index eeae489c68..eab51fd9f8 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -1,3 +1,4 @@ +import hashlib import logging import os import uuid @@ -46,6 +47,22 @@ def wrap_metadata(metadata, **kwargs): return metadata +def _seed_to_uuid4(seed: str) -> str: + """Derive a deterministic UUID4-formatted string from an arbitrary seed. + + uuid4_to_uuid7 requires a valid UUID v4 string, but some Dify identifiers + are not UUIDs (e.g. a workflow_run_id with a "-root" suffix appended to + distinguish the root span from the trace). This helper hashes the seed + with MD5 and patches the version/variant bits so the result satisfies the + UUID v4 contract. + """ + raw = hashlib.md5(seed.encode()).digest() + ba = bytearray(raw) + ba[6] = (ba[6] & 0x0F) | 0x40 # version 4 + ba[8] = (ba[8] & 0x3F) | 0x80 # variant 1 + return str(uuid.UUID(bytes=bytes(ba))) + + def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None): """Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most messages and objects. 
The type-hints of BaseTraceInfo indicates that @@ -95,60 +112,52 @@ class OpikDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id - opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) workflow_metadata = wrap_metadata( trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id ) - root_span_id = None if trace_info.message_id: dify_trace_id = trace_info.trace_id or trace_info.message_id - opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) - - trace_data = { - "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE, - "start_time": trace_info.start_time, - "end_time": trace_info.end_time, - "metadata": workflow_metadata, - "input": wrap_dict("input", trace_info.workflow_run_inputs), - "output": wrap_dict("output", trace_info.workflow_run_outputs), - "thread_id": trace_info.conversation_id, - "tags": ["message", "workflow"], - "project_name": self.project, - } - self.add_trace(trace_data) - - root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) - span_data = { - "id": root_span_id, - "parent_span_id": None, - "trace_id": opik_trace_id, - "name": TraceTaskName.WORKFLOW_TRACE, - "input": wrap_dict("input", trace_info.workflow_run_inputs), - "output": wrap_dict("output", trace_info.workflow_run_outputs), - "start_time": trace_info.start_time, - "end_time": trace_info.end_time, - "metadata": workflow_metadata, - "tags": ["workflow"], - "project_name": self.project, - } - self.add_span(span_data) + trace_name = TraceTaskName.MESSAGE_TRACE + trace_tags = ["message", "workflow"] + root_span_seed = trace_info.workflow_run_id else: - trace_data = { - "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE, - "start_time": trace_info.start_time, - "end_time": trace_info.end_time, - "metadata": workflow_metadata, - 
"input": wrap_dict("input", trace_info.workflow_run_inputs), - "output": wrap_dict("output", trace_info.workflow_run_outputs), - "thread_id": trace_info.conversation_id, - "tags": ["workflow"], - "project_name": self.project, - } - self.add_trace(trace_data) + dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id + trace_name = TraceTaskName.WORKFLOW_TRACE + trace_tags = ["workflow"] + root_span_seed = _seed_to_uuid4(trace_info.workflow_run_id + "-root") + + opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) + + trace_data = { + "id": opik_trace_id, + "name": trace_name, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "thread_id": trace_info.conversation_id, + "tags": trace_tags, + "project_name": self.project, + } + self.add_trace(trace_data) + + root_span_id = prepare_opik_uuid(trace_info.start_time, root_span_seed) + span_data = { + "id": root_span_id, + "parent_span_id": None, + "trace_id": opik_trace_id, + "name": TraceTaskName.WORKFLOW_TRACE, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "tags": ["workflow"], + "project_name": self.project, + } + self.add_span(span_data) # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) @@ -231,15 +240,13 @@ class OpikDataTrace(BaseTraceInstance): else: run_type = "tool" - parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id - if not total_tokens: total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 span_data = { "trace_id": opik_trace_id, "id": prepare_opik_uuid(created_at, node_execution_id), 
- "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), + "parent_span_id": root_span_id, "name": node_name, "type": run_type, "start_time": created_at, diff --git a/api/core/ops/tencent_trace/utils.py b/api/core/ops/tencent_trace/utils.py index 96087951ab..678287ae1d 100644 --- a/api/core/ops/tencent_trace/utils.py +++ b/api/core/ops/tencent_trace/utils.py @@ -6,7 +6,6 @@ import hashlib import random import uuid from datetime import datetime -from typing import cast from opentelemetry.trace import Link, SpanContext, TraceFlags @@ -23,7 +22,7 @@ class TencentTraceUtils: uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4() except Exception as e: raise ValueError(f"Invalid UUID input: {e}") - return cast(int, uuid_obj.int) + return uuid_obj.int @staticmethod def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: @@ -52,9 +51,9 @@ class TencentTraceUtils: @staticmethod def create_link(trace_id_str: str) -> Link: try: - trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int) + trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int except (ValueError, TypeError): - trace_id = cast(int, uuid.uuid4().int) + trace_id = uuid.uuid4().int span_context = SpanContext( trace_id=trace_id, diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 3c5df2b905..60d08b26c9 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Union +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if workflow is None: raise ValueError("unexpected app type") - features_dict = workflow.features_dict + features_dict: dict[str, Any] = 
workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app.app_model_config if app_model_config is None: raise ValueError("unexpected app type") - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 2dc540e6a8..416e0f6b4d 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -157,6 +157,7 @@ class PluginInstallTaskPluginStatus(BaseModel): message: str = Field(description="The message of the install task.") icon: str = Field(description="The icon of the plugin.") labels: I18nObject = Field(description="The labels of the plugin.") + source: str | None = Field(default=None, description="The installation source of the plugin") class PluginInstallTask(BasePluginEntity): diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index de1572410c..cbc846f716 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -65,7 +65,7 @@ class ChromaVector(BaseVector): self._client.get_or_create_collection(collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] @@ -73,6 +73,7 @@ class ChromaVector(BaseVector): collection = self._client.get_or_create_collection(self._collection_name) # FIXME: chromadb using numpy array, fix the type error later collection.upsert(ids=uuids, documents=texts, 
embeddings=embeddings, metadatas=metadatas) # type: ignore + return uuids def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 91bb71bfa6..8e8120fc10 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector): logger.warning("Failed to create inverted index: %s", e) # Continue without inverted index - full-text search will fall back to LIKE - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: """Add documents with embeddings to the collection.""" if not documents: - return + return [] batch_size = self._config.batch_size total_batches = (len(documents) + batch_size - 1) // batch_size + added_ids = [] for i in range(0, len(documents), batch_size): batch_docs = documents[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] + batch_doc_ids = [] + for doc in batch_docs: + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} + batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))) + added_ids.extend(batch_doc_ids) # Execute batch insert through write queue - self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + self._execute_write( + self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches + ) + + return added_ids def _insert_batch( self, batch_docs: list[Document], batch_embeddings: list[list[float]], + batch_doc_ids: list[str], batch_index: int, batch_size: int, total_batches: int, @@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector): 
data_rows = [] vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 - for doc, embedding in zip(batch_docs, batch_embeddings): + for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids): # Optimized: minimal checks for common case, fallback for edge cases - metadata = doc.metadata or {} - - if not isinstance(metadata, dict): - metadata = {} - - doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} # Fast path for JSON serialization try: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 6d28ce25bc..449be6a448 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -74,7 +74,8 @@ class ExtractProcessor: else: suffix = "" # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 - file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" + # Generate a temporary filename under the created temp_dir and ensure the directory exists + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model") if return_text: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 1ddbfc5864..d6b6ca35be 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -204,26 +204,61 @@ class WordExtractor(BaseExtractor): return " ".join(unique_content) def _parse_cell_paragraph(self, paragraph, image_map): - paragraph_content = [] - for run in paragraph.runs: - if run.element.xpath(".//a:blip"): - for blip in run.element.xpath(".//a:blip"): - image_id = 
blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if not image_id: - continue - rel = paragraph.part.rels.get(image_id) - if rel is None: - continue - # For external images, use image_id as key; for internal, use target_part - if rel.is_external: - if image_id in image_map: - paragraph_content.append(image_map[image_id]) - else: - image_part = rel.target_part - if image_part in image_map: - paragraph_content.append(image_map[image_part]) - else: - paragraph_content.append(run.text) + paragraph_content: list[str] = [] + + for child in paragraph._element: + tag = child.tag + if tag == qn("w:hyperlink"): + # Note: w:hyperlink elements may also use w:anchor for internal bookmarks. + # This extractor intentionally only converts external links (HTTP/mailto, etc.) + # that are backed by a relationship id (r:id) with rel.is_external == True. + # Hyperlinks without such an external rel (including anchor-only bookmarks) + # are left as plain text link_text. + r_id = child.get(qn("r:id")) + link_text_parts: list[str] = [] + for run_elem in child.findall(qn("w:r")): + run = Run(run_elem, paragraph) + if run.text: + link_text_parts.append(run.text) + link_text = "".join(link_text_parts).strip() + if r_id: + try: + rel = paragraph.part.rels.get(r_id) + if rel: + target_ref = getattr(rel, "target_ref", None) + if target_ref: + parsed_target = urlparse(str(target_ref)) + if rel.is_external or parsed_target.scheme in ("http", "https", "mailto"): + display_text = link_text or str(target_ref) + link_text = f"[{display_text}]({target_ref})" + except Exception: + logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id) + if link_text: + paragraph_content.append(link_text) + + elif tag == qn("w:r"): + run = Run(child, paragraph) + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + image_id = blip.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) + if not 
image_id: + continue + rel = paragraph.part.rels.get(image_id) + if rel is None: + continue + if rel.is_external: + if image_id in image_map: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map: + paragraph_content.append(image_map[image_part]) + else: + if run.text: + paragraph_content.append(run.text) + return "".join(paragraph_content).strip() def parse_docx(self, docx_path): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index b56ff9edef..8243170c62 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -57,7 +57,7 @@ from core.rag.retrieval.template_prompts import ( from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -127,11 +127,12 @@ class DatasetRetrieval: metadata_filter_document_ids, metadata_condition = None, None if request.metadata_filtering_mode != "disabled": - # Convert workflow layer types to app_config layer types - if not request.metadata_model_config: - raise ValueError("metadata_model_config is required for this method") + app_metadata_model_config = ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}) + if request.metadata_filtering_mode == "automatic": + if not request.metadata_model_config: + raise 
ValueError("metadata_model_config is required for this method") - app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump()) + app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump()) app_metadata_filtering_conditions = None if request.metadata_filtering_conditions is not None: diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 649e2f7358..770df8b050 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -194,6 +194,13 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # Create a new database session with self._session_factory() as session: + existing_model = session.get(WorkflowRun, db_model.id) + if existing_model: + if existing_model.tenant_id != self._tenant_id: + raise ValueError("Unauthorized access to workflow run") + # Preserve the original start time for pause/resume flows. 
+ db_model.created_at = existing_model.created_at + # SQLAlchemy merge intelligently handles both insert and update operations # based on the presence of the primary key session.merge(db_model) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 83e4e53418..f6eccc734b 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,28 +10,19 @@ from typing import Union from uuid import uuid4 import httpx -from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from extensions.ext_database import db as global_db +from dify_graph.file.models import ToolFile as ToolFilePydanticModel from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile logger = logging.getLogger(__name__) -from sqlalchemy.engine import Engine - class ToolFileManager: - _engine: Engine - - def __init__(self, engine: Engine | None = None): - if engine is None: - engine = global_db.engine - self._engine = engine - @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -89,7 +80,7 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{unique_filename}" storage.save(filepath, file_binary) - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, @@ -132,7 +123,7 @@ class ToolFileManager: filename = f"{unique_name}{extension}" filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, @@ -157,7 +148,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with 
session_factory.create_session() as session: tool_file: ToolFile | None = ( session.query(ToolFile) .where( @@ -181,7 +172,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: message_file: MessageFile | None = ( session.query(MessageFile) .where( @@ -217,7 +208,9 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: + def get_file_generator_by_tool_file_id( + self, tool_file_id: str + ) -> tuple[Generator | None, ToolFilePydanticModel | None]: """ get file binary @@ -225,7 +218,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file: ToolFile | None = ( session.query(ToolFile) .where( @@ -239,7 +232,7 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, tool_file + return stream, ToolFilePydanticModel.model_validate(tool_file) # init tool_file_parser diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index d73012375d..aef8b3f779 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -37,6 +37,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN, VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE, VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES, + VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT, } diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 4cbee08a65..8c6b1dedee 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -50,6 +50,7 @@ from 
dify_graph.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, ) from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -250,6 +251,7 @@ class DifyNodeFactory(NodeFactory): model_factory=self._llm_model_factory, model_instance=model_instance, memory=memory, + http_client=self._http_request_http_client, ) if node_type == NodeType.DATASOURCE: @@ -292,6 +294,7 @@ class DifyNodeFactory(NodeFactory): model_factory=self._llm_model_factory, model_instance=model_instance, memory=memory, + http_client=self._http_request_http_client, ) if node_type == NodeType.PARAMETER_EXTRACTOR: @@ -308,6 +311,15 @@ class DifyNodeFactory(NodeFactory): memory=memory, ) + if node_type == NodeType.TOOL: + return ToolNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + tool_file_manager_factory=self._http_request_tool_file_manager_factory(), + ) + return node_class( id=node_id, config=node_config, diff --git a/api/dify_graph/file/models.py b/api/dify_graph/file/models.py index db12d4f57a..dcba00978e 100644 --- a/api/dify_graph/file/models.py +++ b/api/dify_graph/file/models.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any +from uuid import UUID, uuid4 from pydantic import BaseModel, Field, model_validator @@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel): number_limits: int = 0 +class ToolFile(BaseModel): + id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") + user_id: UUID = Field(..., description="ID of the user who owns this file") + tenant_id: UUID = Field(..., description="ID of the tenant/organization") + conversation_id: UUID | None = 
Field(None, description="ID of the associated conversation") + file_key: str = Field(..., max_length=255, description="Storage key for the file") + mimetype: str = Field(..., max_length=255, description="MIME type of the file") + original_url: str | None = Field( + None, max_length=2048, description="Original URL if file was fetched from external source" + ) + name: str = Field(default="", max_length=255, description="Display name of the file") + size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") + + class Config: + from_attributes = True # Enable ORM mode for SQLAlchemy compatibility + populate_by_name = True + + class File(BaseModel): # NOTE: dify_model_identity is a special identifier used to distinguish between # new and old data formats during serialization and deserialization. diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py index 5945e57926..c26b18aac9 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -4,6 +4,7 @@ import json import logging import os import tempfile +import zipfile from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any @@ -82,8 +83,18 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): value = variable.value inputs = {"variable_selector": variable_selector} + if isinstance(value, list): + value = list(filter(lambda x: x, value)) process_data = {"documents": value if isinstance(value, list) else [value]} + if not value: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": ArrayStringSegment(value=[])}, + ) + try: if isinstance(value, list): extracted_text_list = [ @@ -111,6 +122,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): else: raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") except DocumentExtractorError as 
e: + logger.warning(e, exc_info=True) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), @@ -385,6 +397,32 @@ def parser_docx_part(block, doc: Document, content_items, i): content_items.append((i, "table", Table(block, doc))) +def _normalize_docx_zip(file_content: bytes) -> bytes: + """ + Some DOCX files (e.g. exported by Evernote on Windows) are malformed: + ZIP entry names use backslash (\\) as path separator instead of the forward + slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry + "word\\document.xml" is never found when python-docx looks for + "word/document.xml", which triggers a KeyError about a missing relationship. + + This function rewrites the ZIP in-memory, normalizing all entry names to + use forward slashes without touching any actual document content. + """ + try: + with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: + out_buf = io.BytesIO() + with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: + for item in zin.infolist(): + data = zin.read(item.filename) + # Normalize backslash path separators to forward slash + item.filename = item.filename.replace("\\", "/") + zout.writestr(item, data) + return out_buf.getvalue() + except zipfile.BadZipFile: + # Not a valid zip — return as-is and let python-docx report the real error + return file_content + + def _extract_text_from_docx(file_content: bytes) -> str: """ Extract text from a DOCX file. @@ -392,7 +430,15 @@ def _extract_text_from_docx(file_content: bytes) -> str: """ try: doc_file = io.BytesIO(file_content) - doc = docx.Document(doc_file) + try: + doc = docx.Document(doc_file) + except Exception as e: + logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) + # Some DOCX files exported by tools like Evernote on Windows use + # backslash path separators in ZIP entries and/or single-quoted XML + # attributes, both of which break python-docx on Linux. Normalize and retry. 
+ file_content = _normalize_docx_zip(file_content) + doc = docx.Document(io.BytesIO(file_content)) text = [] # Keep track of paragraph and table positions diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d84dda42d6..c67e14ce17 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source from dify_graph.variables import ( ArrayFileSegment, @@ -23,7 +22,11 @@ from dify_graph.variables import ( ) from dify_graph.variables.segments import ArrayObjectSegment -from .entities import KnowledgeRetrievalNodeData +from .entities import ( + Condition, + KnowledgeRetrievalNodeData, + MetadataFilteringCondition, +) from .exc import ( KnowledgeRetrievalNodeError, RateLimitExceededError, @@ -43,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # Output variable for file _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - def __init__( self, id: str, @@ -52,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", rag_retrieval: RAGRetrievalProtocol, - *, - llm_file_saver: LLMFileSaver | None = None, ): super().__init__( id=id, @@ -65,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD self._file_outputs = [] self._rag_retrieval = rag_retrieval - if 
llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - ) - self._llm_file_saver = llm_file_saver - @classmethod def version(cls): return "1" @@ -116,7 +107,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD try: results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) - outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])} + outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -171,6 +162,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if node_data.metadata_filtering_mode is not None: metadata_filtering_mode = node_data.metadata_filtering_mode + resolved_metadata_conditions = ( + self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions) + if node_data.metadata_filtering_conditions + else None + ) + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: @@ -189,7 +186,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD model_mode=model.mode, model_name=model.name, metadata_model_config=node_data.metadata_model_config, - metadata_filtering_conditions=node_data.metadata_filtering_conditions, + metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, query=query, ) @@ -247,7 +244,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD weights=weights, reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_model_config=node_data.metadata_model_config, - 
metadata_filtering_conditions=node_data.metadata_filtering_conditions, + metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) @@ -256,6 +253,48 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD usage = self._rag_retrieval.llm_usage return retrieval_resource_list, usage + def _resolve_metadata_filtering_conditions( + self, conditions: MetadataFilteringCondition + ) -> MetadataFilteringCondition: + if conditions.conditions is None: + return MetadataFilteringCondition( + logical_operator=conditions.logical_operator, + conditions=None, + ) + + variable_pool = self.graph_runtime_state.variable_pool + resolved_conditions: list[Condition] = [] + for cond in conditions.conditions or []: + value = cond.value + if isinstance(value, str): + segment_group = variable_pool.convert_template(value) + if len(segment_group.value) == 1: + resolved_value = segment_group.value[0].to_object() + else: + resolved_value = segment_group.text + elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): + resolved_values = [] + for v in value: # type: ignore + segment_group = variable_pool.convert_template(v) + if len(segment_group.value) == 1: + resolved_values.append(segment_group.value[0].to_object()) + else: + resolved_values.append(segment_group.text) + resolved_value = resolved_values + else: + resolved_value = value + resolved_conditions.append( + Condition( + name=cond.name, + comparison_operator=cond.comparison_operator, + value=resolved_value, + ) + ) + return MetadataFilteringCondition( + logical_operator=conditions.logical_operator or "and", + conditions=resolved_conditions, + ) + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py index b4f64f4093..50e52a3b6f 100644 --- 
a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -1,14 +1,11 @@ import mimetypes import typing as tp -from sqlalchemy import Engine - from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.helper import ssrf_proxy from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager from dify_graph.file import File, FileTransferMethod, FileType -from extensions.ext_database import db as global_db +from dify_graph.nodes.protocols import HttpClientProtocol class LLMFileSaver(tp.Protocol): @@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol): raise NotImplementedError() -EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] - - class FileSaverImpl(LLMFileSaver): - _engine_factory: EngineFactory _tenant_id: str _user_id: str - def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None): - if engine_factory is None: - - def _factory(): - return global_db.engine - - engine_factory = _factory - self._engine_factory = engine_factory + def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): self._user_id = user_id self._tenant_id = tenant_id + self._http_client = http_client def _get_tool_file_manager(self): - return ToolFileManager(engine=self._engine_factory()) + return ToolFileManager() def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = ssrf_proxy.get(url) + http_response = self._http_client.get(url) http_response.raise_for_status() data = http_response.content mime_type_from_header = http_response.headers.get("Content-Type") diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index c7697a0972..5e59c96cd6 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from 
dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import VariablePool from dify_graph.variables import ( ArrayFileSegment, @@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]): credentials_provider: CredentialsProvider, model_factory: ModelFactory, model_instance: ModelInstance, + http_client: HttpClientProtocol, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): @@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]): llm_file_saver = FileSaverImpl( user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + http_client=http_client, ) self._llm_file_saver = llm_file_saver diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py index cc007150f1..62d3bcdca1 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -1,8 +1,10 @@ +from collections.abc import Generator from typing import Any, Protocol import httpx from dify_graph.file import File +from dify_graph.file.models import ToolFile class HttpClientProtocol(Protocol): @@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol): mimetype: str, filename: str | None = None, ) -> Any: ... + + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... 
diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 97535d832d..443d216186 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -28,6 +28,7 @@ from dify_graph.nodes.llm import ( ) from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.protocols import HttpClientProtocol from libs.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData @@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", model_instance: ModelInstance, + http_client: HttpClientProtocol, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): @@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): llm_file_saver = FileSaverImpl( user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + http_client=http_client, ) self._llm_file_saver = llm_file_saver diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 57fb946559..a6e0b710f1 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -1,9 +1,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events 
import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment from dify_graph.variables.variables import ArrayAnyVariable -from extensions.ext_database import db from factories import file_factory -from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import ToolNodeData @@ -36,7 +32,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.runtime import VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool class ToolNode(Node[ToolNodeData]): @@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]): node_type = NodeType.TOOL + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + tool_file_manager_factory: ToolFileManagerProtocol, + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._tool_file_manager_factory = tool_file_manager_factory + @classmethod def version(cls) -> str: return "1" @@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]): tool_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") + _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + if not tool_file: + raise ToolFileError(f"tool file {tool_file_id} not found") mapping = { "tool_file_id": tool_file_id, @@ -294,11 +306,9 @@ class 
ToolNode(Node[ToolNodeData]): assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"tool file {tool_file_id} not exists") + _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + if not tool_file: + raise ToolFileError(f"tool file {tool_file_id} not exists") mapping = { "tool_file_id": tool_file_id, diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py index a2b1af99bb..e3ef6a2897 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -65,9 +65,15 @@ class VariablePool(BaseModel): # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool + # Add conversation variables to the variable pool. When restoring from a serialized + # snapshot, `variable_dictionary` already carries the latest runtime values. + # In that case, keep existing entries instead of overwriting them with the + # bootstrap list. 
for var in self.conversation_variables: - self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) + if self._has(selector): + continue + self.add(selector, var) # Add rag pipeline variables to the variable pool if self.rag_pipeline_variables: rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 1a675b3338..6b904b7d0d 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + 
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 69959acd19..b70c2183d2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from sqlalchemy import select from events.app_event import app_model_config_was_updated @@ -54,9 +56,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s continue tool_type = list(tool.keys())[0] - tool_config = list(tool.values())[0] + tool_config = cast(dict[str, Any], list(tool.values())[0]) if tool_type == "dataset": - dataset_ids.add(tool_config.get("id")) + dataset_id = tool_config.get("id") + if isinstance(dataset_id, str): + dataset_ids.add(dataset_id) # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 46885761a1..fe95cc5816 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -13,6 +13,7 @@ def init_app(app: DifyApp): convert_to_agent_apps, create_tenant, delete_archived_workflow_runs, + export_app_messages, extract_plugins, extract_unique_plugins, file_usage, @@ -66,6 +67,7 @@ def init_app(app: DifyApp): restore_workflow_runs, clean_workflow_runs, clean_expired_messages, + export_app_messages, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/otel/celery_sqlcommenter.py 
b/api/extensions/otel/celery_sqlcommenter.py new file mode 100644 index 0000000000..8abb1ce15a --- /dev/null +++ b/api/extensions/otel/celery_sqlcommenter.py @@ -0,0 +1,114 @@ +""" +Celery SQL comment context for OpenTelemetry SQLCommenter. + +Injects Celery-specific metadata (framework, task_name, traceparent, celery_retries, +routing_key) into SQL comments for queries executed by Celery workers. This improves +trace-to-SQL correlation and debugging in production. + +Uses the OpenTelemetry context key SQLCOMMENTER_ORM_TAGS_AND_VALUES, which is read +by opentelemetry.instrumentation.sqlcommenter_utils._add_framework_tags() when the +SQLAlchemy instrumentor appends comments to SQL statements. +""" + +import logging +from typing import Any + +from celery.signals import task_postrun, task_prerun +from opentelemetry import context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +logger = logging.getLogger(__name__) +_TRACE_PROPAGATOR = TraceContextTextMapPropagator() + +_SQLCOMMENTER_CONTEXT_KEY = "SQLCOMMENTER_ORM_TAGS_AND_VALUES" +_TOKEN_ATTR = "_dify_sqlcommenter_context_token" + + +def _build_celery_sqlcommenter_tags(task: Any) -> dict[str, str | int]: + """Build SQL commenter tags from the current Celery task and OpenTelemetry context.""" + tags: dict[str, str | int] = {} + + try: + tags["framework"] = f"celery:{_get_celery_version()}" + except Exception: + tags["framework"] = "celery:unknown" + + if task and getattr(task, "name", None): + tags["task_name"] = str(task.name) + + traceparent = _get_traceparent() + if traceparent: + tags["traceparent"] = traceparent + + if task and hasattr(task, "request"): + request = task.request + retries = getattr(request, "retries", None) + if retries is not None and retries > 0: + tags["celery_retries"] = int(retries) + + delivery_info = getattr(request, "delivery_info", None) or {} + if isinstance(delivery_info, dict): + routing_key = delivery_info.get("routing_key") + if routing_key: 
+ tags["routing_key"] = str(routing_key) + + return tags + + +def _get_celery_version() -> str: + import celery + + return getattr(celery, "__version__", "unknown") + + +def _get_traceparent() -> str | None: + """Extract traceparent from the current OpenTelemetry context.""" + carrier: dict[str, str] = {} + _TRACE_PROPAGATOR.inject(carrier) + return carrier.get("traceparent") + + +def _on_task_prerun(*args: object, **kwargs: object) -> None: + task = kwargs.get("task") + if not task: + return + + tags = _build_celery_sqlcommenter_tags(task) + if not tags: + return + + current = context.get_current() + new_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, tags, current) + token = context.attach(new_ctx) + setattr(task, _TOKEN_ATTR, token) + + +def _on_task_postrun(*args: object, **kwargs: object) -> None: + task = kwargs.get("task") + if not task: + return + + token = getattr(task, _TOKEN_ATTR, None) + if token is None: + return + + try: + context.detach(token) + except Exception: + logger.debug("Failed to detach SQL commenter context", exc_info=True) + finally: + try: + delattr(task, _TOKEN_ATTR) + except AttributeError: + pass + + +def setup_celery_sqlcommenter() -> None: + """ + Connect Celery task_prerun and task_postrun handlers to inject SQL comment + context for worker queries. Call this from init_celery_worker after + CeleryInstrumentor().instrument() so our handlers run after the OTEL + instrumentor's and the trace context is already attached. 
+ """ + task_prerun.connect(_on_task_prerun, weak=False) + task_postrun.connect(_on_task_postrun, weak=False) diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index a7181d2683..a9ff0eed22 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -67,11 +67,14 @@ def init_celery_worker(*args, **kwargs): from opentelemetry.metrics import get_meter_provider from opentelemetry.trace import get_tracer_provider + from extensions.otel.celery_sqlcommenter import setup_celery_sqlcommenter + tracer_provider = get_tracer_provider() metric_provider = get_meter_provider() if dify_config.DEBUG: logger.info("Initializing OpenTelemetry for Celery worker") CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() + setup_celery_sqlcommenter() def is_instrument_flag_enabled() -> bool: diff --git a/api/migrations/env.py b/api/migrations/env.py index 66a4614e80..3b1fa7bb89 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -66,6 +66,7 @@ def run_migrations_offline(): context.configure( url=url, target_metadata=get_metadata(), literal_binds=True ) + logger.info("Generating offline migration SQL with url: %s", url) with context.begin_transaction(): context.run_migrations() diff --git a/api/models/model.py b/api/models/model.py index 2bf80edb80..ed0614c195 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 import sqlalchemy as sa @@ -15,6 +15,7 @@ from flask import request from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, 
mapped_column +from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS @@ -36,6 +37,259 @@ if TYPE_CHECKING: from .workflow import Workflow +# --- TypedDict definitions for structured dict return types --- + + +class EnabledConfig(TypedDict): + enabled: bool + + +class EmbeddingModelInfo(TypedDict): + embedding_provider_name: str + embedding_model_name: str + + +class AnnotationReplyDisabledConfig(TypedDict): + enabled: Literal[False] + + +class AnnotationReplyEnabledConfig(TypedDict): + id: str + enabled: Literal[True] + score_threshold: float + embedding_model: EmbeddingModelInfo + + +AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig + + +class SensitiveWordAvoidanceConfig(TypedDict): + enabled: bool + type: str + config: dict[str, Any] + + +class AgentToolConfig(TypedDict): + provider_type: str + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] + plugin_unique_identifier: NotRequired[str | None] + credential_id: NotRequired[str | None] + + +class AgentModeConfig(TypedDict): + enabled: bool + strategy: str | None + tools: list[AgentToolConfig | dict[str, Any]] + prompt: str | None + + +class ImageUploadConfig(TypedDict): + enabled: bool + number_limits: int + detail: str + transfer_methods: list[str] + + +class FileUploadConfig(TypedDict): + image: ImageUploadConfig + + +class DeletedToolInfo(TypedDict): + type: str + tool_name: str + provider_id: str + + +class ExternalDataToolConfig(TypedDict): + enabled: bool + variable: str + type: str + config: dict[str, Any] + + +class UserInputFormItemConfig(TypedDict): + variable: str + label: str + description: NotRequired[str] + required: NotRequired[bool] + max_length: NotRequired[int] + options: NotRequired[list[str]] + default: NotRequired[str] + type: NotRequired[str] + config: NotRequired[dict[str, Any]] + + +# Each item is a single-key dict, e.g. 
{"text-input": UserInputFormItemConfig} +UserInputFormItem = dict[str, UserInputFormItemConfig] + + +class DatasetConfigs(TypedDict): + retrieval_model: str + datasets: NotRequired[dict[str, Any]] + top_k: NotRequired[int] + score_threshold: NotRequired[float] + score_threshold_enabled: NotRequired[bool] + reranking_model: NotRequired[dict[str, Any] | None] + weights: NotRequired[dict[str, Any] | None] + reranking_enabled: NotRequired[bool] + reranking_mode: NotRequired[str] + metadata_filtering_mode: NotRequired[str] + metadata_model_config: NotRequired[dict[str, Any] | None] + metadata_filtering_conditions: NotRequired[dict[str, Any] | None] + + +class ChatPromptMessage(TypedDict): + text: str + role: str + + +class ChatPromptConfig(TypedDict, total=False): + prompt: list[ChatPromptMessage] + + +class CompletionPromptText(TypedDict): + text: str + + +class ConversationHistoriesRole(TypedDict): + user_prefix: str + assistant_prefix: str + + +class CompletionPromptConfig(TypedDict): + prompt: CompletionPromptText + conversation_histories_role: NotRequired[ConversationHistoriesRole] + + +class ModelConfig(TypedDict): + provider: str + name: str + mode: str + completion_params: NotRequired[dict[str, Any]] + + +class AppModelConfigDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: EnabledConfig + speech_to_text: EnabledConfig + text_to_speech: EnabledConfig + retriever_resource: EnabledConfig + annotation_reply: AnnotationReplyConfig + more_like_this: EnabledConfig + sensitive_word_avoidance: SensitiveWordAvoidanceConfig + external_data_tools: list[ExternalDataToolConfig] + model: ModelConfig + user_input_form: list[UserInputFormItem] + dataset_query_variable: str | None + pre_prompt: str | None + agent_mode: AgentModeConfig + prompt_type: str + chat_prompt_config: ChatPromptConfig + completion_prompt_config: CompletionPromptConfig + dataset_configs: DatasetConfigs + file_upload: FileUploadConfig + # 
Added dynamically in Conversation.model_config + model_id: NotRequired[str | None] + provider: NotRequired[str | None] + + +class ConversationDict(TypedDict): + id: str + app_id: str + app_model_config_id: str | None + model_provider: str | None + override_model_configs: str | None + model_id: str | None + mode: str + name: str + summary: str | None + inputs: dict[str, Any] + introduction: str | None + system_instruction: str | None + system_instruction_tokens: int + status: str + invoke_from: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + read_at: datetime | None + read_account_id: str | None + dialogue_count: int + created_at: datetime + updated_at: datetime + + +class MessageDict(TypedDict): + id: str + app_id: str + conversation_id: str + model_id: str | None + inputs: dict[str, Any] + query: str + total_price: Decimal | None + message: dict[str, Any] + answer: str + status: str + error: str | None + message_metadata: dict[str, Any] + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + agent_based: bool + workflow_run_id: str | None + + +class MessageFeedbackDict(TypedDict): + id: str + app_id: str + conversation_id: str + message_id: str + rating: str + content: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + + +class MessageFileInfo(TypedDict, total=False): + belongs_to: str | None + upload_file_id: str | None + id: str + tenant_id: str + type: str + transfer_method: str + remote_url: str | None + related_id: str | None + filename: str | None + extension: str | None + mime_type: str | None + size: int + dify_model_identity: str + url: str | None + + +class ExtraContentDict(TypedDict, total=False): + type: str + workflow_run_id: str + + +class TraceAppConfigDict(TypedDict): + id: str + app_id: str + tracing_provider: str | None + tracing_config: dict[str, Any] + is_active: 
bool + created_at: str | None + updated_at: str | None + + class DifySetup(TypeBase): __tablename__ = "dify_setups" __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -176,7 +430,7 @@ class App(Base): return str(self.mode) @property - def deleted_tools(self) -> list[dict[str, str]]: + def deleted_tools(self) -> list[DeletedToolInfo]: from core.tools.tool_manager import ToolManager, ToolProviderType from services.plugin.plugin_service import PluginService @@ -257,7 +511,7 @@ class App(Base): provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) } - deleted_tools: list[dict[str, str]] = [] + deleted_tools: list[DeletedToolInfo] = [] for tool in tools: keys = list(tool.keys()) @@ -364,35 +618,38 @@ class AppModelConfig(TypeBase): return app @property - def model_dict(self) -> dict[str, Any]: - return json.loads(self.model) if self.model else {} + def model_dict(self) -> ModelConfig: + return cast(ModelConfig, json.loads(self.model) if self.model else {}) @property def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] @property - def suggested_questions_after_answer_dict(self) -> dict[str, Any]: - return ( + def suggested_questions_after_answer_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer - else {"enabled": False} + else {"enabled": False}, ) @property - def speech_to_text_dict(self) -> dict[str, Any]: - return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} + def speech_to_text_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) @property - def text_to_speech_dict(self) -> dict[str, Any]: - return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} + def 
text_to_speech_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) @property - def retriever_resource_dict(self) -> dict[str, Any]: - return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + def retriever_resource_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + ) @property - def annotation_reply_dict(self) -> dict[str, Any]: + def annotation_reply_dict(self) -> AnnotationReplyConfig: annotation_setting = ( db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) @@ -415,56 +672,62 @@ class AppModelConfig(TypeBase): return {"enabled": False} @property - def more_like_this_dict(self) -> dict[str, Any]: - return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} + def more_like_this_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) @property - def sensitive_word_avoidance_dict(self) -> dict[str, Any]: - return ( + def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: + return cast( + SensitiveWordAvoidanceConfig, json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance - else {"enabled": False, "type": "", "configs": []} + else {"enabled": False, "type": "", "config": {}}, ) @property - def external_data_tools_list(self) -> list[dict[str, Any]]: + def external_data_tools_list(self) -> list[ExternalDataToolConfig]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> list[dict[str, Any]]: + def user_input_form_list(self) -> list[UserInputFormItem]: return json.loads(self.user_input_form) if self.user_input_form else [] @property - def agent_mode_dict(self) -> dict[str, Any]: - return ( 
+ def agent_mode_dict(self) -> AgentModeConfig: + return cast( + AgentModeConfig, json.loads(self.agent_mode) if self.agent_mode - else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + else {"enabled": False, "strategy": None, "tools": [], "prompt": None}, ) @property - def chat_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} + def chat_prompt_config_dict(self) -> ChatPromptConfig: + return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}) @property - def completion_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} + def completion_prompt_config_dict(self) -> CompletionPromptConfig: + return cast( + CompletionPromptConfig, + json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}, + ) @property - def dataset_configs_dict(self) -> dict[str, Any]: + def dataset_configs_dict(self) -> DatasetConfigs: if self.dataset_configs: - dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) + dataset_configs = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: - return dataset_configs + return cast(DatasetConfigs, dataset_configs) return { "retrieval_model": "multiple", } @property - def file_upload_dict(self) -> dict[str, Any]: - return ( + def file_upload_dict(self) -> FileUploadConfig: + return cast( + FileUploadConfig, json.loads(self.file_upload) if self.file_upload else { @@ -474,10 +737,10 @@ class AppModelConfig(TypeBase): "detail": "high", "transfer_methods": ["remote_url", "local_file"], } - } + }, ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> AppModelConfigDict: return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -501,36 +764,42 @@ class AppModelConfig(TypeBase): 
"file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: Mapping[str, Any]): + def from_model_config_dict(self, model_config: AppModelConfigDict): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( - json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None ) self.suggested_questions_after_answer = ( - json.dumps(model_config["suggested_questions_after_answer"]) + json.dumps(model_config.get("suggested_questions_after_answer")) if model_config.get("suggested_questions_after_answer") else None ) - self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None - self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None - self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.speech_to_text = ( + json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None + ) + self.text_to_speech = ( + json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None + ) + self.more_like_this = ( + json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None + ) self.sensitive_word_avoidance = ( - json.dumps(model_config["sensitive_word_avoidance"]) + json.dumps(model_config.get("sensitive_word_avoidance")) if model_config.get("sensitive_word_avoidance") else None ) self.external_data_tools = ( - json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None ) - self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.model 
= json.dumps(model_config.get("model")) if model_config.get("model") else None self.user_input_form = ( - json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None ) self.dataset_query_variable = model_config.get("dataset_query_variable") - self.pre_prompt = model_config["pre_prompt"] - self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.pre_prompt = model_config.get("pre_prompt") + self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None self.retriever_resource = ( - json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) self.prompt_type = model_config.get("prompt_type", "simple") self.chat_prompt_config = ( @@ -823,24 +1092,26 @@ class Conversation(Base): self._inputs = inputs @property - def model_config(self): - model_config = {} + def model_config(self) -> AppModelConfigDict: + model_config = cast(AppModelConfigDict, {}) app_model_config: AppModelConfig | None = None if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - model_config = override_model_configs + model_config = cast(AppModelConfigDict, override_model_configs) else: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) if "model" in override_model_configs: # where is app_id? 
- app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs) + app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict( + cast(AppModelConfigDict, override_model_configs) + ) model_config = app_model_config.to_dict() else: - model_config["configs"] = override_model_configs + model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key] else: app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() @@ -1015,7 +1286,7 @@ class Conversation(Base): def in_debug_mode(self) -> bool: return self.override_model_configs is not None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ConversationDict: return { "id": self.id, "app_id": self.app_id, @@ -1295,7 +1566,7 @@ class Message(Base): return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property - def message_files(self) -> list[dict[str, Any]]: + def message_files(self) -> list[MessageFileInfo]: from factories import file_factory message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() @@ -1350,10 +1621,13 @@ class Message(Base): ) files.append(file) - result: list[dict[str, Any]] = [ - {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} - for (file, message_file) in zip(files, message_files) - ] + result = cast( + list[MessageFileInfo], + [ + {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ], + ) db.session.commit() return result @@ -1363,7 +1637,7 @@ class Message(Base): self._extra_contents = list(contents) @property - def extra_contents(self) -> list[dict[str, Any]]: + def extra_contents(self) -> list[ExtraContentDict]: return getattr(self, "_extra_contents", []) @property @@ -1379,7 +1653,7 @@ class 
Message(Base): return None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageDict: return { "id": self.id, "app_id": self.app_id, @@ -1403,7 +1677,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> Message: + def from_dict(cls, data: MessageDict) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase): account = db.session.query(Account).where(Account.id == self.from_account_id).first() return account - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageFeedbackDict: return { "id": str(self.id), "app_id": str(self.app_id), @@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase): return result @property - def parameters_dict(self) -> dict[str, Any]: - return cast(dict[str, Any], json.loads(self.parameters)) + def parameters_dict(self) -> dict[str, str]: + return cast(dict[str, str], json.loads(self.parameters)) class Site(Base): @@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase): def tracing_config_str(self) -> str: return json.dumps(self.tracing_config_dict) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> TraceAppConfigDict: return { "id": self.id, "app_id": self.app_id, diff --git a/api/pyproject.toml b/api/pyproject.toml index 84b95fb226..bf786f4584 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.1.77", - "markdown~=3.5.1", + "markdown~=3.8.1", "mlflow-skinny>=3.0.0", "numpy~=1.26.4", "openpyxl~=3.1.5", @@ -113,7 +113,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=38.2.0", "lxml-stubs~=0.5.1", - "basedpyright~=1.31.0", + "basedpyright~=1.38.2", "ruff~=0.14.0", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", @@ -167,12 +167,12 @@ dev = [ "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", - "mypy~=1.17.1", + "mypy~=1.19.1", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. 
Uncomment when resolved. "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.54.0", + "pyrefly>=0.55.0", ] ############################################################ @@ -247,3 +247,13 @@ module = [ "extensions.logstore.repositories.logstore_api_workflow_run_repository", ] ignore_errors = true + +[tool.pyrefly] +project-includes = ["."] +project-excludes = [ + ".venv", + "migrations/", +] +python-platform = "linux" +python-version = "3.11.0" +infer-with-first-use = false diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt new file mode 100644 index 0000000000..d3b2ede745 --- /dev/null +++ b/api/pyrefly-local-excludes.txt @@ -0,0 +1,200 @@ +configs/middleware/cache/redis_pubsub_config.py +controllers/console/app/annotation.py +controllers/console/app/app.py +controllers/console/app/app_import.py +controllers/console/app/mcp_server.py +controllers/console/app/site.py +controllers/console/auth/email_register.py +controllers/console/human_input_form.py +controllers/console/init_validate.py +controllers/console/ping.py +controllers/console/setup.py +controllers/console/version.py +controllers/console/workspace/trigger_providers.py +controllers/service_api/app/annotation.py +controllers/web/workflow_events.py +core/agent/fc_agent_runner.py +core/app/apps/advanced_chat/app_generator.py +core/app/apps/advanced_chat/app_runner.py +core/app/apps/advanced_chat/generate_task_pipeline.py +core/app/apps/agent_chat/app_generator.py +core/app/apps/base_app_generate_response_converter.py +core/app/apps/base_app_generator.py +core/app/apps/chat/app_generator.py +core/app/apps/common/workflow_response_converter.py +core/app/apps/completion/app_generator.py +core/app/apps/pipeline/pipeline_generator.py +core/app/apps/pipeline/pipeline_runner.py +core/app/apps/workflow/app_generator.py +core/app/apps/workflow/app_runner.py +core/app/apps/workflow/generate_task_pipeline.py +core/app/apps/workflow_app_runner.py 
+core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +core/datasource/datasource_manager.py +core/external_data_tool/api/api.py +core/llm_generator/llm_generator.py +core/llm_generator/output_parser/structured_output.py +core/mcp/mcp_client.py +core/ops/aliyun_trace/data_exporter/traceclient.py +core/ops/arize_phoenix_trace/arize_phoenix_trace.py +core/ops/mlflow_trace/mlflow_trace.py +core/ops/ops_trace_manager.py +core/ops/tencent_trace/client.py +core/ops/tencent_trace/utils.py +core/plugin/backwards_invocation/base.py +core/plugin/backwards_invocation/model.py +core/prompt/utils/extract_thread_messages.py +core/rag/datasource/keyword/jieba/jieba.py +core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +core/rag/datasource/vdb/baidu/baidu_vector.py +core/rag/datasource/vdb/chroma/chroma_vector.py +core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +core/rag/datasource/vdb/couchbase/couchbase_vector.py +core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +core/rag/datasource/vdb/lindorm/lindorm_vector.py +core/rag/datasource/vdb/matrixone/matrixone_vector.py +core/rag/datasource/vdb/milvus/milvus_vector.py +core/rag/datasource/vdb/myscale/myscale_vector.py +core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +core/rag/datasource/vdb/opensearch/opensearch_vector.py +core/rag/datasource/vdb/oracle/oraclevector.py +core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +core/rag/datasource/vdb/relyt/relyt_vector.py +core/rag/datasource/vdb/tablestore/tablestore_vector.py +core/rag/datasource/vdb/tencent/tencent_vector.py +core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +core/rag/datasource/vdb/tidb_vector/tidb_vector.py +core/rag/datasource/vdb/upstash/upstash_vector.py 
+core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +core/rag/datasource/vdb/weaviate/weaviate_vector.py +core/rag/extractor/csv_extractor.py +core/rag/extractor/excel_extractor.py +core/rag/extractor/firecrawl/firecrawl_app.py +core/rag/extractor/firecrawl/firecrawl_web_extractor.py +core/rag/extractor/html_extractor.py +core/rag/extractor/jina_reader_extractor.py +core/rag/extractor/markdown_extractor.py +core/rag/extractor/notion_extractor.py +core/rag/extractor/pdf_extractor.py +core/rag/extractor/text_extractor.py +core/rag/extractor/unstructured/unstructured_doc_extractor.py +core/rag/extractor/unstructured/unstructured_eml_extractor.py +core/rag/extractor/unstructured/unstructured_epub_extractor.py +core/rag/extractor/unstructured/unstructured_markdown_extractor.py +core/rag/extractor/unstructured/unstructured_msg_extractor.py +core/rag/extractor/unstructured/unstructured_ppt_extractor.py +core/rag/extractor/unstructured/unstructured_pptx_extractor.py +core/rag/extractor/unstructured/unstructured_xml_extractor.py +core/rag/extractor/watercrawl/client.py +core/rag/extractor/watercrawl/extractor.py +core/rag/extractor/watercrawl/provider.py +core/rag/extractor/word_extractor.py +core/rag/index_processor/processor/paragraph_index_processor.py +core/rag/index_processor/processor/parent_child_index_processor.py +core/rag/index_processor/processor/qa_index_processor.py +core/rag/retrieval/router/multi_dataset_function_call_router.py +core/rag/summary_index/summary_index.py +core/repositories/sqlalchemy_workflow_execution_repository.py +core/repositories/sqlalchemy_workflow_node_execution_repository.py +core/tools/__base/tool.py +core/tools/mcp_tool/provider.py +core/tools/plugin_tool/provider.py +core/tools/utils/message_transformer.py +core/tools/utils/web_reader_tool.py +core/tools/workflow_as_tool/provider.py +core/trigger/debug/event_selectors.py +core/trigger/entities/entities.py +core/trigger/provider.py +core/workflow/workflow_entry.py 
+dify_graph/entities/workflow_execution.py +dify_graph/file/file_manager.py +dify_graph/graph_engine/error_handler.py +dify_graph/graph_engine/layers/execution_limits.py +dify_graph/nodes/agent/agent_node.py +dify_graph/nodes/base/node.py +dify_graph/nodes/code/code_node.py +dify_graph/nodes/datasource/datasource_node.py +dify_graph/nodes/document_extractor/node.py +dify_graph/nodes/human_input/human_input_node.py +dify_graph/nodes/if_else/if_else_node.py +dify_graph/nodes/iteration/iteration_node.py +dify_graph/nodes/knowledge_index/knowledge_index_node.py +dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +dify_graph/nodes/list_operator/node.py +dify_graph/nodes/llm/node.py +dify_graph/nodes/loop/loop_node.py +dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +dify_graph/nodes/question_classifier/question_classifier_node.py +dify_graph/nodes/start/start_node.py +dify_graph/nodes/template_transform/template_transform_node.py +dify_graph/nodes/tool/tool_node.py +dify_graph/nodes/trigger_plugin/trigger_event_node.py +dify_graph/nodes/trigger_schedule/trigger_schedule_node.py +dify_graph/nodes/trigger_webhook/node.py +dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +dify_graph/nodes/variable_assigner/v1/node.py +dify_graph/nodes/variable_assigner/v2/node.py +dify_graph/variables/types.py +extensions/ext_fastopenapi.py +extensions/logstore/repositories/logstore_api_workflow_run_repository.py +extensions/otel/instrumentation.py +extensions/otel/runtime.py +extensions/storage/aliyun_oss_storage.py +extensions/storage/aws_s3_storage.py +extensions/storage/azure_blob_storage.py +extensions/storage/baidu_obs_storage.py +extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +extensions/storage/clickzetta_volume/file_lifecycle.py +extensions/storage/google_cloud_storage.py +extensions/storage/huawei_obs_storage.py +extensions/storage/opendal_storage.py +extensions/storage/oracle_oci_storage.py 
+extensions/storage/supabase_storage.py +extensions/storage/tencent_cos_storage.py +extensions/storage/volcengine_tos_storage.py +factories/variable_factory.py +libs/external_api.py +libs/gmpy2_pkcs10aep_cipher.py +libs/helper.py +libs/login.py +libs/module_loading.py +libs/oauth.py +libs/oauth_data_source.py +models/trigger.py +models/workflow.py +repositories/sqlalchemy_api_workflow_node_execution_repository.py +repositories/sqlalchemy_api_workflow_run_repository.py +repositories/sqlalchemy_execution_extra_content_repository.py +schedule/queue_monitor_task.py +services/account_service.py +services/audio_service.py +services/auth/firecrawl/firecrawl.py +services/auth/jina.py +services/auth/jina/jina.py +services/auth/watercrawl/watercrawl.py +services/conversation_service.py +services/dataset_service.py +services/document_indexing_proxy/document_indexing_task_proxy.py +services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py +services/external_knowledge_service.py +services/plugin/plugin_migration.py +services/recommend_app/buildin/buildin_retrieval.py +services/recommend_app/database/database_retrieval.py +services/recommend_app/remote/remote_retrieval.py +services/summary_index_service.py +services/tools/tools_transform_service.py +services/trigger/trigger_provider_service.py +services/trigger/trigger_subscription_builder_service.py +services/trigger/webhook_service.py +services/workflow_draft_variable_service.py +services/workflow_event_snapshot_service.py +services/workflow_service.py +tasks/app_generate/workflow_execute_task.py +tasks/regenerate_summary_index_task.py +tasks/trigger_processing_tasks.py +tasks/workflow_cfs_scheduler/cfs_scheduler.py +tasks/workflow_execution_tasks.py diff --git a/api/pyrefly.toml b/api/pyrefly.toml deleted file mode 100644 index 01f4c5a529..0000000000 --- a/api/pyrefly.toml +++ /dev/null @@ -1,8 +0,0 @@ -project-includes = ["."] -project-excludes = [ - ".venv", - "migrations/", -] -python-platform = "linux" 
-python-version = "3.11.0" -infer-with-first-use = false diff --git a/api/pytest.ini b/api/pytest.ini index 4a9470fa0c..588dafe7eb 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,5 +1,6 @@ [pytest] -addopts = --cov=./api --cov-report=json +pythonpath = . +addopts = --cov=./api --cov-report=json --import-mode=importlib env = ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com @@ -19,7 +20,7 @@ env = GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c - HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b + HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa MOCK_SWITCH = true diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index 77d6b5a138..01642e397e 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -21,6 +21,10 @@ celery_redis = Redis( ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None, ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None, ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None, + # Add conservative socket timeouts and health checks to avoid long-lived half-open sockets + socket_timeout=5, + socket_connect_timeout=5, + health_check_interval=30, ) logger = logging.getLogger(__name__) diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py index 3b3e478793..df5058d70a 100644 --- a/api/schedule/trigger_provider_refresh_task.py +++ b/api/schedule/trigger_provider_refresh_task.py @@ -3,6 +3,7 @@ import math import time from collections.abc import Iterable, Sequence +from celery import group from sqlalchemy import 
ColumnElement, and_, func, or_, select from sqlalchemy.engine.row import Row from sqlalchemy.orm import Session @@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None: lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions) acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl) - enqueued: int = 0 - for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired): - if not is_locked: - continue - trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id) - enqueued += 1 + if not any(acquired): + continue + + jobs = [ + trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id) + for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired) + if is_locked + ] + result = group(jobs).apply_async() + enqueued = len(jobs) logger.info( - "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d", + "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s", page + 1, pages, len(subscriptions), sum(1 for x in acquired if x), enqueued, + result, ) logger.info("Trigger refresh scan done: due=%d", total_due) diff --git a/api/schedule/workflow_schedule_task.py b/api/schedule/workflow_schedule_task.py index d68b9565ec..2fee9e467d 100644 --- a/api/schedule/workflow_schedule_task.py +++ b/api/schedule/workflow_schedule_task.py @@ -1,6 +1,6 @@ import logging -from celery import group, shared_task +from celery import current_app, group, shared_task from sqlalchemy import and_, select from sqlalchemy.orm import Session, sessionmaker @@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None: with session_factory() as session: total_dispatched = 0 - # Process in batches until we've handled all due schedules or hit the limit while True: due_schedules = _fetch_due_schedules(session) if not due_schedules: break - dispatched_count = _process_schedules(session, due_schedules) - total_dispatched += dispatched_count + with 
current_app.producer_or_acquire() as producer: # type: ignore + dispatched_count = _process_schedules(session, due_schedules, producer) + total_dispatched += dispatched_count - logger.debug("Batch processed: %d dispatched", dispatched_count) - - # Circuit breaker: check if we've hit the per-tick limit (if enabled) - if ( - dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0 - and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK - ): - logger.warning( - "Circuit breaker activated: reached dispatch limit (%d), will continue next tick", - dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK, - ) - break + logger.debug("Batch processed: %d dispatched", dispatched_count) + # Circuit breaker: check if we've hit the per-tick limit (if enabled) + if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched: + logger.warning( + "Circuit breaker activated: reached dispatch limit (%d), will continue next tick", + dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK, + ) + break if total_dispatched > 0: - logger.info("Total processed: %d dispatched", total_dispatched) + logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched) def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]: @@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]: return list(due_schedules) -def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int: +def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int: """Process schedules: check quota, update next run time and dispatch to Celery in parallel.""" if not schedules: return 0 @@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) if tasks_to_dispatch: job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch) - job.apply_async() + job.apply_async(producer=producer) 
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 5790c8b9ec..06f4ccb90e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,6 +4,7 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum +from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig, IconType +from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -523,7 +524,7 @@ class AppDslService: if not app.app_model_config: app_model_config = AppModelConfig( app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(model_config) + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) app_model_config.id = str(uuid4()) app.app_model_config_id = app_model_config.id diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 6f54f90734..3bc30cb323 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,12 +1,12 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: 
AppMode): + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: if app_mode == AppMode.CHAT: return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index ce6826ef5c..aba8954f1a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import TypedDict, cast +from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination @@ -187,7 +187,7 @@ class AppService: for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool)) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -388,7 +388,7 @@ class AppService: agent_config = app_model_config.agent_mode_dict # get all tools - tools = agent_config.get("tools", []) + tools = cast(list[dict[str, Any]], agent_config.get("tools", [])) url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1b698fad17..1794ea9947 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -2,6 +2,7 @@ import io import logging import uuid from collections.abc import Generator +from typing import cast from flask import Response, stream_with_context from werkzeug.datastructures import FileStorage @@ -106,7 +107,7 @@ class AudioService: if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get("voice") + voice = cast(str | None, text_to_speech_dict.get("voice")) model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( diff --git 
a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 95a50f0512..f3b2adb965 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -824,6 +824,7 @@ class DatasourceProviderService: "langgenius/firecrawl_datasource", "langgenius/notion_datasource", "langgenius/jina_datasource", + "watercrawl/watercrawl_datasource", ]: datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}") credentials = self.list_datasource_credentials( diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index 817dbd95f8..598f9692eb 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -3,6 +3,7 @@ import logging from pydantic import BaseModel +from configs import dify_config from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError @@ -28,6 +29,11 @@ class CheckCredentialPolicyComplianceRequest(BaseModel): return data +class PreUninstallPluginRequest(BaseModel): + tenant_id: str + plugin_unique_identifier: str + + class CredentialPolicyViolationError(BaseServiceError): pass @@ -55,3 +61,21 @@ class PluginManagerService: body.dify_credential_id, ret.get("result", False), ) + + @classmethod + def try_pre_uninstall_plugin(cls, body: PreUninstallPluginRequest): + try: + # the invocation must be synchronous. + EnterprisePluginManagerRequest.send_request( + "POST", + "/pre-uninstall-plugin", + json=body.model_dump(), + raise_for_status=True, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + except Exception: + logger.exception( + "failed to perform pre uninstall plugin hook. 
tenant_id: %s, plugin_unique_identifier: %s", + body.tenant_id, + body.plugin_unique_identifier, + ) diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 7b43c49686..80deb37a56 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -245,5 +245,6 @@ class EmailDeliveryTestHandler: ) if token: substitutions["form_token"] = token - substitutions["form_link"] = _build_form_link(token) or "" + link = _build_form_link(token) + substitutions["form_link"] = link if link is not None else f"/form/{token}" return substitutions diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 6eed3a6b38..55a3ffde78 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -32,6 +32,10 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider import Provider, ProviderCredential from models.provider_ids import GenericProviderID +from services.enterprise.plugin_manager_service import ( + PluginManagerService, + PreUninstallPluginRequest, +) from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope @@ -519,6 +523,13 @@ class PluginService: if not plugin: return manager.uninstall(tenant_id, plugin_installation_id) + if dify_config.ENTERPRISE_ENABLED: + PluginManagerService.try_pre_uninstall_plugin( + PreUninstallPluginRequest( + tenant_id=tenant_id, + plugin_unique_identifier=plugin.plugin_unique_identifier, + ) + ) with Session(db.engine) as session, session.begin(): plugin_id = plugin.plugin_id logger.info("Deleting credentials for plugin: %s", plugin_id) diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index d0dfbc1070..cee18387b3 100644 --- 
a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -63,7 +63,12 @@ class RagPipelineTransformService: ): node = self._deal_file_extensions(node) if node.get("data", {}).get("type") == "knowledge-index": - node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node) + knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) + if dataset.tenant_id != current_user.current_tenant_id: + raise ValueError("Unauthorized") + node = self._deal_knowledge_index( + knowledge_configuration, dataset, indexing_technique, retrieval_model, node + ) new_nodes.append(node) if new_nodes: graph["nodes"] = new_nodes @@ -155,14 +160,13 @@ class RagPipelineTransformService: def _deal_knowledge_index( self, + knowledge_configuration: KnowledgeConfiguration, dataset: Dataset, - doc_form: str, indexing_technique: str | None, retrieval_model: RetrievalSetting | None, node: dict, ): knowledge_configuration_dict = node.get("data", {}) - knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict) if indexing_technique == "high_quality": knowledge_configuration.embedding_model = dataset.embedding_model diff --git a/api/services/retention/conversation/message_export_service.py b/api/services/retention/conversation/message_export_service.py new file mode 100644 index 0000000000..fbe0d2795d --- /dev/null +++ b/api/services/retention/conversation/message_export_service.py @@ -0,0 +1,304 @@ +""" +Export app messages to JSONL.GZ format. + +Outputs: conversation_id, message_id, query, answer, inputs (raw JSON), +retriever_resources (from message_metadata), feedback (user feedbacks array). + +Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1. +Does NOT touch Message.inputs / Message.user_feedback properties. 
+""" + +import datetime +import gzip +import json +import logging +import tempfile +from collections import defaultdict +from collections.abc import Generator, Iterable +from pathlib import Path, PurePosixPath +from typing import Any, BinaryIO, cast + +import orjson +import sqlalchemy as sa +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import select, tuple_ +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import Message, MessageFeedback + +logger = logging.getLogger(__name__) + +MAX_FILENAME_BASE_LENGTH = 1024 +FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz") + + +class AppMessageExportFeedback(BaseModel): + id: str + app_id: str + conversation_id: str + message_id: str + rating: str + content: str | None = None + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + created_at: str + updated_at: str + + model_config = ConfigDict(extra="forbid") + + +class AppMessageExportRecord(BaseModel): + conversation_id: str + message_id: str + query: str + answer: str + inputs: dict[str, Any] + retriever_resources: list[Any] = Field(default_factory=list) + feedback: list[AppMessageExportFeedback] = Field(default_factory=list) + + model_config = ConfigDict(extra="forbid") + + +class AppMessageExportStats(BaseModel): + batches: int = 0 + total_messages: int = 0 + messages_with_feedback: int = 0 + total_feedbacks: int = 0 + + model_config = ConfigDict(extra="forbid") + + +class AppMessageExportService: + @staticmethod + def validate_export_filename(filename: str) -> str: + normalized = filename.strip() + if not normalized: + raise ValueError("--filename must not be empty.") + + normalized_lower = normalized.lower() + if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES): + raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.") + + if normalized.startswith("/"): + 
raise ValueError("--filename must be a relative path; absolute paths are not allowed.") + + if "\\" in normalized: + raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.") + + if "//" in normalized: + raise ValueError("--filename must not contain empty path segments ('//').") + + if len(normalized) > MAX_FILENAME_BASE_LENGTH: + raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.") + + for ch in normalized: + if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127: + raise ValueError("--filename must not contain control characters or NUL.") + + parts = PurePosixPath(normalized).parts + if not parts: + raise ValueError("--filename must include a file name.") + + if any(part in (".", "..") for part in parts): + raise ValueError("--filename must not contain '.' or '..' path segments.") + + return normalized + + @property + def output_gz_name(self) -> str: + return f"{self._filename_base}.jsonl.gz" + + @property + def output_jsonl_name(self) -> str: + return f"{self._filename_base}.jsonl" + + def __init__( + self, + app_id: str, + end_before: datetime.datetime, + filename: str, + *, + start_from: datetime.datetime | None = None, + batch_size: int = 1000, + use_cloud_storage: bool = False, + dry_run: bool = False, + ) -> None: + if start_from and start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})") + + self._app_id = app_id + self._end_before = end_before + self._start_from = start_from + self._filename_base = self.validate_export_filename(filename) + self._batch_size = batch_size + self._use_cloud_storage = use_cloud_storage + self._dry_run = dry_run + + def run(self) -> AppMessageExportStats: + stats = AppMessageExportStats() + + logger.info( + "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s", + self._app_id, + self._start_from, + self._end_before, + self._dry_run, + self._use_cloud_storage, + 
self.output_gz_name, + ) + + if self._dry_run: + for _ in self._iter_records_with_stats(stats): + pass + self._finalize_stats(stats) + return stats + + if self._use_cloud_storage: + self._export_to_cloud(stats) + else: + self._export_to_local(stats) + + self._finalize_stats(stats) + return stats + + def iter_records(self) -> Generator[AppMessageExportRecord, None, None]: + for batch in self._iter_record_batches(): + yield from batch + + @staticmethod + def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None: + with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz: + for record in records: + gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n") + + def _export_to_local(self, stats: AppMessageExportStats) -> None: + output_path = Path.cwd() / self.output_gz_name + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("wb") as output_file: + self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file) + + def _export_to_cloud(self, stats: AppMessageExportStats) -> None: + with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp: + self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp)) + tmp.seek(0) + data = tmp.read() + + storage.save(self.output_gz_name, data) + logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name) + + def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]: + for record in self.iter_records(): + self._update_stats(stats, record) + yield record + + @staticmethod + def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None: + stats.total_messages += 1 + if record.feedback: + stats.messages_with_feedback += 1 + stats.total_feedbacks += len(record.feedback) + + def _finalize_stats(self, stats: AppMessageExportStats) -> None: + if stats.total_messages == 0: + stats.batches = 0 + return + stats.batches = 
(stats.total_messages + self._batch_size - 1) // self._batch_size + + def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]: + cursor: tuple[datetime.datetime, str] | None = None + while True: + rows, cursor = self._fetch_batch(cursor) + if not rows: + break + + message_ids = [str(row.id) for row in rows] + feedbacks_map = self._fetch_feedbacks(message_ids) + yield [self._build_record(row, feedbacks_map) for row in rows] + + def _fetch_batch( + self, cursor: tuple[datetime.datetime, str] | None + ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]: + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select( + Message.id, + Message.conversation_id, + Message.query, + Message.answer, + Message._inputs, # pyright: ignore[reportPrivateUsage] + Message.message_metadata, + Message.created_at, + ) + .where( + Message.app_id == self._app_id, + Message.created_at < self._end_before, + ) + .order_by(Message.created_at, Message.id) + .limit(self._batch_size) + ) + + if self._start_from: + stmt = stmt.where(Message.created_at >= self._start_from) + + if cursor: + stmt = stmt.where( + tuple_(Message.created_at, Message.id) + > tuple_( + sa.literal(cursor[0], type_=sa.DateTime()), + sa.literal(cursor[1], type_=Message.id.type), + ) + ) + + rows = list(session.execute(stmt).all()) + + if not rows: + return [], cursor + + last = rows[-1] + return rows, (last.created_at, last.id) + + def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]: + if not message_ids: + return {} + + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(MessageFeedback) + .where( + MessageFeedback.message_id.in_(message_ids), + MessageFeedback.from_source == "user", + ) + .order_by(MessageFeedback.message_id, MessageFeedback.created_at) + ) + feedbacks = list(session.scalars(stmt).all()) + + result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list) + for feedback in 
feedbacks: + result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict())) + return result + + @staticmethod + def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord: + retriever_resources: list[Any] = [] + if row.message_metadata: + try: + metadata = json.loads(row.message_metadata) + value = metadata.get("retriever_resources", []) + if isinstance(value, list): + retriever_resources = value + except (json.JSONDecodeError, TypeError): + pass + + message_id = str(row.id) + return AppMessageExportRecord( + conversation_id=str(row.conversation_id), + message_id=message_id, + query=row.query, + answer=row.answer, + inputs=row._inputs if isinstance(row._inputs, dict) else {}, + retriever_resources=retriever_resources, + feedback=feedbacks_map.get(message_id, []), + ) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py index f7836a2b14..04265817d7 100644 --- a/api/services/retention/conversation/messages_clean_service.py +++ b/api/services/retention/conversation/messages_clean_service.py @@ -12,6 +12,7 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.model import ( App, AppAnnotationHitHistory, @@ -142,7 +143,7 @@ class MessagesCleanService: if batch_size <= 0: raise ValueError(f"batch_size ({batch_size}) must be greater than 0") - end_before = datetime.datetime.now() - datetime.timedelta(days=days) + end_before = naive_utc_now() - datetime.timedelta(days=days) logger.info( "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s", diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index e323b3cda9..b6e5367023 100644 --- a/api/services/tools/tools_transform_service.py +++ 
b/api/services/tools/tools_transform_service.py @@ -33,6 +33,8 @@ logger = logging.getLogger(__name__) class ToolTransformService: + _MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10 + @classmethod def get_tool_provider_icon_url( cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str] @@ -435,6 +437,46 @@ class ToolTransformService: :return: list of ToolParameter instances """ + def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str: + """ + Resolve a JSON schema property type while guarding against cyclic or deeply nested unions. + """ + if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH: + return "string" + prop_type = prop.get("type") + if isinstance(prop_type, list): + non_null_types = [type_name for type_name in prop_type if type_name != "null"] + if non_null_types: + return non_null_types[0] + if prop_type: + return "string" + elif isinstance(prop_type, str): + if prop_type == "null": + return "string" + return prop_type + + for union_key in ("anyOf", "oneOf"): + union_schemas = prop.get(union_key) + if not isinstance(union_schemas, list): + continue + + for union_schema in union_schemas: + if not isinstance(union_schema, dict): + continue + union_type = resolve_property_type(union_schema, depth + 1) + if union_type != "null": + return union_type + + all_of_schemas = prop.get("allOf") + if isinstance(all_of_schemas, list): + for all_of_schema in all_of_schemas: + if not isinstance(all_of_schema, dict): + continue + all_of_type = resolve_property_type(all_of_schema, depth + 1) + if all_of_type != "null": + return all_of_type + return "string" + def create_parameter( name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None ) -> ToolParameter: @@ -461,10 +503,7 @@ class ToolTransformService: parameters = [] for name, prop in props.items(): current_description = prop.get("description", "") - prop_type = prop.get("type", "string") - - if isinstance(prop_type, list): - 
prop_type = prop_type[0] + prop_type = resolve_property_type(prop) if prop_type in TYPE_MAPPING: prop_type = TYPE_MAPPING[prop_type] input_schema = prop if prop_type in COMPLEX_TYPES else None diff --git a/api/services/website_service.py b/api/services/website_service.py index fe48c3b08e..15ec4657d9 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -124,7 +124,7 @@ class WebsiteService: if provider == "firecrawl": plugin_id = "langgenius/firecrawl_datasource" elif provider == "watercrawl": - plugin_id = "langgenius/watercrawl_datasource" + plugin_id = "watercrawl/watercrawl_datasource" elif provider == "jinareader": plugin_id = "langgenius/jina_datasource" else: diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 11edcf151f..b3f36d8f44 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -1,9 +1,10 @@ import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Sequence +from typing import Any, Protocol import click -from celery import shared_task +from celery import current_app, shared_task from configs import dify_config from core.db.session_factory import session_factory @@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) +class CeleryTaskLike(Protocol): + def delay(self, *args: Any, **kwargs: Any) -> Any: ... + + def apply_async(self, *args: Any, **kwargs: Any) -> Any: ... 
+ + @shared_task(queue="dataset") def document_indexing_task(dataset_id: str, document_ids: list): """ @@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): def _document_indexing_with_tenant_queue( - tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] -): + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike +) -> None: try: _document_indexing(dataset_id, document_ids) except Exception: @@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue( logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) if next_tasks: - for next_task in next_tasks: - document_task = DocumentTask(**next_task) - # Process the next waiting task - # Keep the flag set to indicate a task is running - tenant_isolated_task_queue.set_task_waiting_time() - task_func.delay( # type: ignore - tenant_id=document_task.tenant_id, - dataset_id=document_task.dataset_id, - document_ids=document_task.document_ids, - ) + with current_app.producer_or_acquire() as producer: # type: ignore + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.apply_async( + kwargs={ + "tenant_id": document_task.tenant_id, + "dataset_id": document_task.dataset_id, + "document_ids": document_task.document_ids, + }, + producer=producer, + ) + else: # No more waiting tasks, clear the flag tenant_isolated_task_queue.delete_task_key() diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index e4273e16b5..6493833edc 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -14,7 +14,7 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) -@shared_task(queue="dataset") 
+@shared_task(queue="dataset_summary") def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): """ Async generate summary index for document segments. diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 6ad04aab0d..5d201bd801 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -6,7 +6,6 @@ import typing import click from celery import shared_task -from core.helper.marketplace import record_install_plugin_event from core.plugin.entities.marketplace import MarketplacePluginSnapshot from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller @@ -166,7 +165,6 @@ def process_tenant_plugin_autoupgrade_check_task( # execute upgrade new_unique_identifier = manifest.latest_package_identifier - record_install_plugin_event(new_unique_identifier) click.echo( click.style( f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}", diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 093342d1a3..52f66dddb8 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -3,12 +3,13 @@ import json import logging import time import uuid -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor +from itertools import islice from typing import Any import click -from celery import shared_task # type: ignore +from celery import group, shared_task from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker @@ -27,6 +28,11 @@ from services.file_service import FileService logger = logging.getLogger(__name__) +def chunked(iterable: Sequence, size: int): + it = iter(iterable) + return 
iter(lambda: list(islice(it, size)), []) + + @shared_task(queue="pipeline") def rag_pipeline_run_task( rag_pipeline_invoke_entities_file_id: str, @@ -83,16 +89,24 @@ def rag_pipeline_run_task( logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: - for next_file_id in next_file_ids: - # Process the next waiting task - # Keep the flag set to indicate a task is running - tenant_isolated_task_queue.set_task_waiting_time() - rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") - if isinstance(next_file_id, bytes) - else next_file_id, - tenant_id=tenant_id, - ) + for batch in chunked(next_file_ids, 100): + jobs = [] + for next_file_id in batch: + tenant_isolated_task_queue.set_task_waiting_time() + + file_id = ( + next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id + ) + + jobs.append( + rag_pipeline_run_task.s( + rag_pipeline_invoke_entities_file_id=file_id, + tenant_id=tenant_id, + ) + ) + + if jobs: + group(jobs).apply_async() else: # No more waiting tasks, clear the flag tenant_isolated_task_queue.delete_task_key() diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index cf8988d13e..39c2f4103e 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -16,7 +16,7 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) -@shared_task(queue="dataset") +@shared_task(queue="dataset_summary") def regenerate_summary_index_task( dataset_id: str, regenerate_reason: str = "summary_model_changed", diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py index 8160807e48..f36c596eb8 100644 --- 
a/api/tests/integration_tests/controllers/console/app/test_description_validation.py +++ b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -5,14 +5,10 @@ This test module validates the 400-character limit enforcement for App descriptions across all creation and editing endpoints. """ -import os import sys import pytest -# Add the API root to Python path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) - class TestAppDescriptionValidationUnit: """Unit tests for description validation function""" diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index b4779ebcdd..2aca9f5157 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db @@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), + http_client=MagicMock(spec=HttpClientProtocol), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f70bf46979..23cb56d2a5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory from dify_graph.enums import 
WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -55,11 +56,14 @@ def init_tool_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + node = ToolNode( id=str(uuid.uuid4()), config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + tool_file_manager_factory=tool_file_manager_factory, ) return node diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index d6d2d30305..2a23f1ea7d 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -10,8 +10,11 @@ more reliable and realistic test scenarios. import logging import os from collections.abc import Generator +from contextlib import contextmanager from pathlib import Path +from typing import Protocol, TypeVar +import psycopg2 import pytest from flask import Flask from flask.testing import FlaskClient @@ -31,6 +34,25 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level logger = logging.getLogger(__name__) +class _CloserProtocol(Protocol): + """_Closer is any type which implement the close() method.""" + + def close(self): + """close the current object, release any external resouece (file, transaction, connection etc.) + associated with it. 
+ """ + pass + + +_Closer = TypeVar("_Closer", bound=_CloserProtocol) + + +@contextmanager +def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]: + yield closer + closer.close() + + class DifyTestContainers: """ Manages all test containers required for Dify integration tests. @@ -97,45 +119,28 @@ class DifyTestContainers: wait_for_logs(self.postgres, "is ready to accept connections", timeout=30) logger.info("PostgreSQL container is ready and accepting connections") - # Install uuid-ossp extension for UUID generation - logger.info("Installing uuid-ossp extension...") - try: - import psycopg2 - - conn = psycopg2.connect( - host=db_host, - port=db_port, - user=self.postgres.username, - password=self.postgres.password, - database=self.postgres.dbname, - ) - conn.autocommit = True - cursor = conn.cursor() - cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') - cursor.close() - conn.close() + conn = psycopg2.connect( + host=db_host, + port=db_port, + user=self.postgres.username, + password=self.postgres.password, + database=self.postgres.dbname, + ) + conn.autocommit = True + with _auto_close(conn): + with conn.cursor() as cursor: + # Install uuid-ossp extension for UUID generation + logger.info("Installing uuid-ossp extension...") + cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') logger.info("uuid-ossp extension installed successfully") - except Exception as e: - logger.warning("Failed to install uuid-ossp extension: %s", e) - # Create plugin database for dify-plugin-daemon - logger.info("Creating plugin database...") - try: - conn = psycopg2.connect( - host=db_host, - port=db_port, - user=self.postgres.username, - password=self.postgres.password, - database=self.postgres.dbname, - ) - conn.autocommit = True - cursor = conn.cursor() - cursor.execute("CREATE DATABASE dify_plugin;") - cursor.close() - conn.close() + # NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement + # inside a transaction. 
However, the `CREATE DATABASE` statement cannot run inside a transaction block. + with _auto_close(conn.cursor()) as cursor: + # Create plugin database for dify-plugin-daemon + logger.info("Creating plugin database...") + cursor.execute("CREATE DATABASE dify_plugin;") logger.info("Plugin database created successfully") - except Exception as e: - logger.warning("Failed to create plugin database: %s", e) # Set up storage environment variables os.environ.setdefault("STORAGE_TYPE", "opendal") @@ -258,23 +263,16 @@ class DifyTestContainers: containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon] for container in containers: if container: - try: - container_name = container.image - logger.info("Stopping container: %s", container_name) - container.stop() - logger.info("Successfully stopped container: %s", container_name) - except Exception as e: - # Log error but don't fail the test cleanup - logger.warning("Failed to stop container %s: %s", container, e) + container_name = container.image + logger.info("Stopping container: %s", container_name) + container.stop() + logger.info("Successfully stopped container: %s", container_name) # Stop and remove the network if self.network: - try: - logger.info("Removing Docker network...") - self.network.remove() - logger.info("Successfully removed Docker network") - except Exception as e: - logger.warning("Failed to remove Docker network: %s", e) + logger.info("Removing Docker network...") + self.network.remove() + logger.info("Successfully removed Docker network") self._containers_started = False logger.info("All test containers stopped and cleaned up successfully") diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index e5d3655771..d783a08233 100644 --- 
a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -8,6 +8,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset, Document from services.account_service import AccountService, TenantService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestGetAvailableDatasetsIntegration: @@ -22,7 +23,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -83,7 +84,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -136,7 +137,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -189,7 +190,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -252,7 +253,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), 
interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -286,7 +287,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company()) tenant1 = account1.current_tenant @@ -295,7 +296,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company()) tenant2 = account2.current_tenant @@ -362,7 +363,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -384,7 +385,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -445,7 +446,7 @@ class TestKnowledgeRetrievalIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -513,7 +514,7 @@ class TestKnowledgeRetrievalIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + 
password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -561,7 +562,7 @@ class TestKnowledgeRetrievalIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/helpers/__init__.py b/api/tests/test_containers_integration_tests/helpers/__init__.py index 40d03889a9..0b753abd1f 100644 --- a/api/tests/test_containers_integration_tests/helpers/__init__.py +++ b/api/tests/test_containers_integration_tests/helpers/__init__.py @@ -1 +1,24 @@ """Helper utilities for integration tests.""" + +import re + + +def generate_valid_password(fake, length: int = 12) -> str: + """Generate a password that always satisfies the project's password validation rules. + + The password validation rule in ``api/libs/password.py`` requires passwords to + contain **both letters and digits** with a minimum length of 8: + + ``^(?=.*[a-zA-Z])(?=.*\\d).{8,}$`` + + ``Faker.password()`` does **not** guarantee that the generated password will + contain both character types, which can cause intermittent test failures. + + This helper re-generates until the result is valid (typically first attempt). 
+ """ + for _ in range(100): + pwd = fake.password(length=length) + if re.search(r"[a-zA-Z]", pwd) and re.search(r"\d", pwd): + return pwd + # Fallback: should never be reached in practice + return fake.password(length=max(length - 2, 6)) + "a1" diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 8595f5bf14..9354a3ac35 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -20,6 +20,7 @@ from services.errors.account import ( TenantNotFoundError, ) from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAccountService: @@ -53,7 +54,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -133,7 +134,7 @@ class TestAccountService: email=email, name=name, interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) def test_create_account_email_in_freeze( @@ -145,7 +146,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True @@ -169,7 +170,7 @@ class TestAccountService: """ fake = Faker() email = 
fake.email() - password = fake.password(length=12) + password = generate_valid_password(fake) with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) @@ -180,7 +181,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -208,8 +209,8 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - correct_password = fake.password(length=12) - wrong_password = fake.password(length=12) + correct_password = generate_valid_password(fake) + wrong_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -234,7 +235,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - new_password = fake.password(length=12) + new_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -267,7 +268,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -297,8 +298,8 @@ class TestAccountService: fake = Faker() email = fake.email() name = 
fake.name() - old_password = fake.password(length=12) - new_password = fake.password(length=12) + old_password = generate_valid_password(fake) + new_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -327,9 +328,9 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - old_password = fake.password(length=12) - wrong_password = fake.password(length=12) - new_password = fake.password(length=12) + old_password = generate_valid_password(fake) + wrong_password = generate_valid_password(fake) + new_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -354,7 +355,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - old_password = fake.password(length=12) + old_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -378,7 +379,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies[ @@ -412,7 +413,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks 
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies[ @@ -437,7 +438,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies[ @@ -535,7 +536,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -563,7 +564,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) updated_name = fake.name() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -592,7 +593,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -615,7 +616,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -645,7 +646,7 @@ class 
TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -684,7 +685,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -714,7 +715,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -747,7 +748,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -792,7 +793,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -825,7 +826,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password 
= generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -864,7 +865,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -892,7 +893,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -926,7 +927,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -957,7 +958,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -997,7 +998,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks 
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1043,7 +1044,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1080,7 +1081,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1110,7 +1111,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1139,7 +1140,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) wrong_code = fake.numerify(text="######") # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -1259,7 +1260,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks 
mock_external_service_dependencies[ "feature_service" @@ -1291,10 +1292,10 @@ class TestTenantService: tenant_name = fake.company() email1 = fake.email() name1 = fake.name() - password1 = fake.password(length=12) + password1 = generate_valid_password(fake) email2 = fake.email() name2 = fake.name() - password2 = fake.password(length=12) + password2 = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1332,7 +1333,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1364,7 +1365,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant1_name = fake.company() tenant2_name = fake.company() # Setup mocks @@ -1403,7 +1404,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies[ @@ -1441,7 +1442,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1466,7 +1467,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant1_name = fake.company() tenant2_name = fake.company() # Setup mocks @@ -1507,7 +1508,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ 
-1534,7 +1535,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies[ @@ -1562,10 +1563,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1631,7 +1632,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1664,10 +1665,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1705,7 +1706,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) invalid_action = "invalid_action_that_doesnt_exist" # Setup mocks mock_external_service_dependencies[ @@ -1738,7 +1739,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1770,10 +1771,10 @@ class 
TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1829,7 +1830,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1861,10 +1862,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) non_member_email = fake.email() non_member_name = fake.name() - non_member_password = fake.password(length=12) + non_member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1900,10 +1901,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1949,10 +1950,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ 
"feature_service" @@ -2006,10 +2007,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -2071,7 +2072,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) workspace_name = fake.company() # Setup mocks mock_external_service_dependencies[ @@ -2110,7 +2111,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) existing_tenant_name = fake.company() new_workspace_name = fake.company() # Setup mocks @@ -2151,7 +2152,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) workspace_name = fake.company() # Setup mocks to disable workspace creation mock_external_service_dependencies[ @@ -2178,13 +2179,13 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) normal_email = fake.email() normal_name = fake.name() - normal_password = fake.password(length=12) + normal_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -2244,13 +2245,13 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = 
fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) operator_email = fake.email() operator_name = fake.name() - operator_password = fake.password(length=12) + operator_password = generate_valid_password(fake) normal_email = fake.email() normal_name = fake.name() - normal_password = fake.password(length=12) + normal_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -2351,7 +2352,7 @@ class TestRegisterService: fake = Faker() admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2399,7 +2400,7 @@ class TestRegisterService: fake = Faker() admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2440,7 +2441,7 @@ class TestRegisterService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2531,7 +2532,7 @@ class TestRegisterService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2576,7 +2577,7 @@ class TestRegisterService: fake = 
Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2614,7 +2615,7 @@ class TestRegisterService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2653,7 +2654,7 @@ class TestRegisterService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2690,7 +2691,7 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) new_member_email = fake.email() language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks @@ -2760,10 +2761,10 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) existing_member_email = fake.email() existing_member_name = fake.name() - existing_member_password = fake.password(length=12) + existing_member_password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True 
@@ -2824,10 +2825,10 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) existing_pending_member_email = fake.email() existing_pending_member_name = fake.name() - existing_pending_member_password = fake.password(length=12) + existing_pending_member_password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2914,10 +2915,10 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) already_in_tenant_email = fake.email() already_in_tenant_name = fake.name() - already_in_tenant_password = fake.password(length=12) + already_in_tenant_password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2967,7 +2968,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3011,7 +3012,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True 
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3058,7 +3059,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3101,7 +3102,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3144,7 +3145,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3212,7 +3213,7 @@ class TestRegisterService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) invalid_tenant_id = fake.uuid4() token = fake.uuid4() # Setup mocks @@ -3263,7 +3264,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) token = fake.uuid4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -3313,7 +3314,7 @@ class TestRegisterService: 
tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) token = fake.uuid4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 45839fd463..4759d244fd 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -11,6 +11,7 @@ from models.model import AppModelConfig, Conversation, EndUser, Message, Message from services.account_service import AccountService, TenantService from services.agent_service import AgentService from services.app_service import AppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAgentService: @@ -111,7 +112,7 @@ class TestAgentService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 004d643955..a260d823a2 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -9,6 +9,7 @@ from models import Account from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService from services.app_service import AppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class 
TestAnnotationService: @@ -78,7 +79,7 @@ class TestAnnotationService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index b8bf8543bc..7ce7357b41 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from models.api_based_extension import APIBasedExtension from services.account_service import AccountService, TenantService from services.api_based_extension_service import APIBasedExtensionService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAPIBasedExtensionService: @@ -55,7 +56,7 @@ class TestAPIBasedExtensionService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index e2a450b90c..8a362e1f5e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -9,6 +9,7 @@ from models.model import App, AppModelConfig from services.account_service import AccountService, TenantService from services.app_dsl_service import AppDslService, ImportMode, 
ImportStatus from services.app_service import AppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAppDslService: @@ -89,7 +90,7 @@ class TestAppDslService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 787a99f3e8..5155d50b0e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -10,6 +10,7 @@ from models.model import EndUser from models.workflow import Workflow from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAppGenerateService: @@ -147,7 +148,7 @@ class TestAppGenerateService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index fc3b20aaae..d79f80c009 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -8,6 +8,7 @@ from constants.model_template import default_app_templates from models import Account from models.model import App, Site 
from services.account_service import AccountService, TenantService +from tests.test_containers_integration_tests.helpers import generate_valid_password # Delay import of AppService to avoid circular dependency # from services.app_service import AppService @@ -56,7 +57,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -112,7 +113,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -155,7 +156,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -203,7 +204,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -259,7 +260,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -334,7 +335,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, 
name=fake.company()) tenant = account.current_tenant @@ -404,7 +405,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -473,7 +474,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -526,7 +527,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -585,7 +586,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -645,7 +646,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -705,7 +706,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -756,7 +757,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + 
password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -808,7 +809,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -868,7 +869,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -907,7 +908,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -947,7 +948,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -997,7 +998,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -1039,7 +1040,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git 
a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py new file mode 100644 index 0000000000..44525e0036 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -0,0 +1,497 @@ +""" +Container-backed integration tests for dataset permission services on the real SQL path. + +This module exercises persisted DatasetPermission rows and dataset permission +checks with testcontainers-backed infrastructure instead of database-chain mocks. +""" + +from uuid import uuid4 + +import pytest + +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + Dataset, + DatasetPermission, + DatasetPermissionEnum, +) +from services.dataset_service import DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionTestDataFactory: + """Create persisted entities and request payloads for dataset permission integration tests.""" + + @staticmethod + def create_account_with_tenant( + role: TenantAccountRole = TenantAccountRole.NORMAL, + tenant: Tenant | None = None, + ) -> tuple[Account, Tenant]: + """Create a real account and tenant with specified role.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + if tenant is None: + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db.session.add_all([account, tenant]) + else: + db.session.add(account) + + db.session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + tenant_id: str, + created_by: str, + permission: 
DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + name: str = "Test Dataset", + ) -> Dataset: + """Create a real dataset with specified attributes.""" + dataset = Dataset( + tenant_id=tenant_id, + name=name, + description="desc", + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=created_by, + permission=permission, + provider="vendor", + retrieval_model={"top_k": 2}, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_dataset_permission( + dataset_id: str, + account_id: str, + tenant_id: str, + has_permission: bool = True, + ) -> DatasetPermission: + """Create a real DatasetPermission instance.""" + permission = DatasetPermission( + dataset_id=dataset_id, + account_id=account_id, + tenant_id=tenant_id, + has_permission=has_permission, + ) + db.session.add(permission) + db.session.commit() + return permission + + @staticmethod + def build_user_list_payload(user_ids: list[str]) -> list[dict[str, str]]: + """Build the request payload shape used by partial-member list updates.""" + return [{"user_id": user_id} for user_id in user_ids] + + +class TestDatasetPermissionServiceGetPartialMemberList: + """Verify partial-member list reads against persisted DatasetPermission rows.""" + + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + """ + Test retrieving partial member list with multiple members. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + user_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + user_3, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + expected_account_ids = [user_1.id, user_2.id, user_3.id] + for account_id in expected_account_ids: + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, account_id, tenant.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert set(result) == set(expected_account_ids) + assert len(result) == 3 + + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + """ + Test retrieving partial member list with single member. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + expected_account_ids = [user.id] + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert set(result) == set(expected_account_ids) + assert len(result) == 1 + + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + """ + Test retrieving partial member list when no members exist. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert result == [] + assert len(result) == 0 + + +class TestDatasetPermissionServiceUpdatePartialMemberList: + """Verify partial-member list updates against persisted DatasetPermission rows.""" + + def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + """ + Test adding new partial members to a dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert set(result) == {member_1.id, member_2.id} + + def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + """ + Test replacing existing partial members with new ones. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + old_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + old_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + new_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + new_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users) + + new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id]) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert set(result) == {new_member_1.id, new_member_2.id} + + def test_update_partial_member_list_empty_list(self, db_session_with_containers): + """ + Test updating with empty member list (clearing all members). 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, []) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + """ + Test error handling and rollback on database error. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + existing_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + replacement_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]), + ) + user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id]) + rollback_called = {"count": 0} + original_rollback = db.session.rollback + + # Act / Assert + with pytest.MonkeyPatch.context() as mp: + + def _raise_commit(): + raise Exception("Database connection error") + + def _rollback_and_mark(): + rollback_called["count"] += 1 + original_rollback() + + mp.setattr("services.dataset_service.db.session.commit", _raise_commit) + mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + with pytest.raises(Exception, match="Database connection error"): + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert rollback_called["count"] == 1 + assert result == [existing_member.id] + assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1 + + +class TestDatasetPermissionServiceClearPartialMemberList: + """Verify partial-member clearing against persisted DatasetPermission rows.""" + + def test_clear_partial_member_list_success(self, db_session_with_containers): + """ + Test successful clearing of partial member list. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + + # Act + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + """ + Test clearing partial member list when no members exist. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + # Act + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + """ + Test error handling and rollback on database error. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + rollback_called = {"count": 0} + original_rollback = db.session.rollback + + # Act / Assert + with pytest.MonkeyPatch.context() as mp: + + def _raise_commit(): + raise Exception("Database connection error") + + def _rollback_and_mark(): + rollback_called["count"] += 1 + original_rollback() + + mp.setattr("services.dataset_service.db.session.commit", _raise_commit) + mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + with pytest.raises(Exception, match="Database connection error"): + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert rollback_called["count"] == 1 + assert set(result) == {member_1.id, member_2.id} + assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2 + + +class TestDatasetServiceCheckDatasetPermission: + """Verify dataset access checks against persisted partial-member permissions.""" + + def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + """ + Test that user with explicit permission can access partial_members dataset. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act (should not raise) + DatasetService.check_dataset_permission(dataset, user) + + # Assert + permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert user.id in permissions + + def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + """ + Test error when user without permission tries to access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + # Act & Assert + with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): + DatasetService.check_dataset_permission(dataset, user) + + +class TestDatasetServiceCheckDatasetOperatorPermission: + """Verify operator permission checks against persisted partial-member permissions.""" + + def test_check_dataset_operator_permission_partial_members_with_permission_success( + self, db_session_with_containers + ): + """ + Test that user with explicit permission can access partial_members dataset. 
+ """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act (should not raise) + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + # Assert + permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert user.id in permissions + + def test_check_dataset_operator_permission_partial_members_without_permission_error( + self, db_session_with_containers + ): + """ + Test error when user without permission tries to access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + # Act & Assert + with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py new file mode 100644 index 0000000000..c47e35791d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -0,0 +1,244 @@ +"""Container-backed integration tests for DatasetService.delete_dataset real SQL 
paths.""" + +from unittest.mock import patch +from uuid import uuid4 + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document +from services.dataset_service import DatasetService + + +class DatasetDeleteIntegrationDataFactory: + """Create persisted entities used by delete_dataset integration tests.""" + + @staticmethod + def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]: + """Persist an owner account, tenant, and tenant join for dataset deletion tests.""" + account = Account( + email=f"owner-{uuid4()}@example.com", + name="Owner", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant( + name=f"tenant-{uuid4()}", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + db_session_with_containers, + tenant_id: str, + created_by: str, + *, + indexing_technique: str | None, + chunk_structure: str | None, + index_struct: str | None = '{"type": "paragraph"}', + collection_binding_id: str | None = None, + pipeline_id: str | None = None, + ) -> Dataset: + """Persist a dataset with delete_dataset-relevant fields configured.""" + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type="upload_file", + indexing_technique=indexing_technique, + index_struct=index_struct, + created_by=created_by, + collection_binding_id=collection_binding_id, + pipeline_id=pipeline_id, + chunk_structure=chunk_structure, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return 
dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + tenant_id: str, + dataset_id: str, + created_by: str, + doc_form: str = "text_model", + ) -> Document: + """Persist a document so dataset.doc_form resolves through the real document path.""" + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type="upload_file", + batch=f"batch-{uuid4()}", + name="Document", + created_from="upload_file", + created_by=created_by, + doc_form=doc_form, + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + +class TestDatasetServiceDeleteDataset: + """Integration coverage for DatasetService.delete_dataset using testcontainers.""" + + def test_delete_dataset_with_documents_success(self, db_session_with_containers): + """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique="high_quality", + chunk_structure=None, + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + DatasetDeleteIntegrationDataFactory.create_document( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=owner.id, + doc_form="text_model", + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_called_once_with( + dataset.id, + dataset.tenant_id, + 
dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + dataset.pipeline_id, + ) + + def test_delete_empty_dataset_success(self, db_session_with_containers): + """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique=None, + chunk_structure=None, + index_struct=None, + collection_binding_id=None, + pipeline_id=None, + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + """Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique=None, + chunk_structure="text_model", + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + 
assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique="high_quality", + chunk_structure=None, + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_not_found(self, db_session_with_containers): + """Return False without scheduling cleanup when the target dataset does not exist.""" + # Arrange + owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + missing_dataset_id = str(uuid4()) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(missing_dataset_id, owner) + + # Assert + assert result is False + clean_dataset_delay.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_message_export_service.py b/api/tests/test_containers_integration_tests/services/test_message_export_service.py new file mode 100644 index 0000000000..200f688ae9 --- /dev/null +++ 
b/api/tests/test_containers_integration_tests/services/test_message_export_service.py @@ -0,0 +1,233 @@ +import datetime +import json +import uuid +from decimal import Decimal + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats + + +class TestAppMessageExportServiceIntegration: + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers: Session): + yield + db_session_with_containers.query(DatasetRetrieverResource).delete() + db_session_with_containers.query(AppAnnotationHitHistory).delete() + db_session_with_containers.query(SavedMessage).delete() + db_session_with_containers.query(MessageFile).delete() + db_session_with_containers.query(MessageAgentThought).delete() + db_session_with_containers.query(MessageChain).delete() + db_session_with_containers.query(MessageAnnotation).delete() + db_session_with_containers.query(MessageFeedback).delete() + db_session_with_containers.query(Message).delete() + db_session_with_containers.query(Conversation).delete() + db_session_with_containers.query(App).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() + + @staticmethod + def _create_app_context(session: Session) -> tuple[App, Conversation]: + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="tester", + interface_language="en-US", + status="active", + ) + session.add(account) + session.flush() + + tenant = 
Tenant(name=f"tenant-{uuid.uuid4()}", status="normal") + session.add(tenant) + session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + session.add(join) + session.flush() + + app = App( + tenant_id=tenant.id, + name="export-app", + description="integration test app", + mode="chat", + enable_site=True, + enable_api=True, + api_rpm=60, + api_rph=3600, + is_demo=False, + is_public=False, + created_by=account.id, + updated_by=account.id, + ) + session.add(app) + session.flush() + + conversation = Conversation( + app_id=app.id, + app_model_config_id=str(uuid.uuid4()), + model_provider="openai", + model_id="gpt-4o-mini", + mode="chat", + name="conv", + inputs={"seed": 1}, + status="normal", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + session.add(conversation) + session.commit() + return app, conversation + + @staticmethod + def _create_message( + session: Session, + app: App, + conversation: Conversation, + created_at: datetime.datetime, + *, + query: str, + answer: str, + inputs: dict, + message_metadata: str | None, + ) -> Message: + message = Message( + app_id=app.id, + conversation_id=conversation.id, + model_provider="openai", + model_id="gpt-4o-mini", + inputs=inputs, + query=query, + answer=answer, + message=[{"role": "assistant", "content": answer}], + message_tokens=10, + message_unit_price=Decimal("0.001"), + answer_tokens=20, + answer_unit_price=Decimal("0.002"), + total_price=Decimal("0.003"), + currency="USD", + message_metadata=message_metadata, + from_source="api", + from_end_user_id=conversation.from_end_user_id, + created_at=created_at, + ) + session.add(message) + session.flush() + return message + + def test_iter_records_with_stats(self, db_session_with_containers: Session): + app, conversation = self._create_app_context(db_session_with_containers) + + first_inputs = { + "plain": "v1", + "nested": {"a": 1, "b": [1, {"x": True}]}, + "list": 
["x", 2, {"y": "z"}], + } + second_inputs = {"other": "value", "items": [1, 2, 3]} + + base_time = datetime.datetime(2026, 2, 25, 10, 0, 0) + first_message = self._create_message( + db_session_with_containers, + app, + conversation, + created_at=base_time, + query="q1", + answer="a1", + inputs=first_inputs, + message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}), + ) + second_message = self._create_message( + db_session_with_containers, + app, + conversation, + created_at=base_time + datetime.timedelta(minutes=1), + query="q2", + answer="a2", + inputs=second_inputs, + message_metadata=None, + ) + + user_feedback_1 = MessageFeedback( + app_id=app.id, + conversation_id=conversation.id, + message_id=first_message.id, + rating="like", + from_source="user", + content="first", + from_end_user_id=conversation.from_end_user_id, + ) + user_feedback_2 = MessageFeedback( + app_id=app.id, + conversation_id=conversation.id, + message_id=first_message.id, + rating="dislike", + from_source="user", + content="second", + from_end_user_id=conversation.from_end_user_id, + ) + admin_feedback = MessageFeedback( + app_id=app.id, + conversation_id=conversation.id, + message_id=first_message.id, + rating="like", + from_source="admin", + content="should-be-filtered", + from_account_id=str(uuid.uuid4()), + ) + db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback]) + user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2) + user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3) + admin_feedback.created_at = base_time + datetime.timedelta(minutes=4) + db_session_with_containers.commit() + + service = AppMessageExportService( + app_id=app.id, + start_from=base_time - datetime.timedelta(minutes=1), + end_before=base_time + datetime.timedelta(minutes=10), + filename="unused", + batch_size=1, + dry_run=True, + ) + stats = AppMessageExportStats() + records = list(service._iter_records_with_stats(stats)) + 
service._finalize_stats(stats) + + assert len(records) == 2 + assert records[0].message_id == first_message.id + assert records[1].message_id == second_message.id + + assert records[0].inputs == first_inputs + assert records[1].inputs == second_inputs + + assert records[0].retriever_resources == [{"dataset_id": "ds-1"}] + assert records[1].retriever_resources == [] + + assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"] + assert [feedback.content for feedback in records[0].feedback] == ["first", "second"] + assert records[1].feedback == [] + + assert stats.batches == 2 + assert stats.total_messages == 2 + assert stats.messages_with_feedback == 1 + assert stats.total_feedbacks == 2 diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 19a684a58a..a6d7bf27fd 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -13,6 +13,7 @@ from services.errors.message import ( SuggestedQuestionsAfterAnswerDisabledError, ) from services.message_service import MessageService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestMessageService: @@ -95,7 +96,7 @@ class TestMessageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -633,7 +634,7 @@ class TestMessageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company()) diff --git 
a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index e3ec1d1df3..cc403ef5a2 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -8,6 +8,7 @@ from models.model import EndUser, Message from models.web import SavedMessage from services.app_service import AppService from services.saved_message_service import SavedMessageService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestSavedMessageService: @@ -64,7 +65,7 @@ class TestSavedMessageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 912aa3dd2f..e0ea8211f6 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -10,6 +10,7 @@ from core.trigger.entities.entities import Subscription as TriggerSubscriptionEn from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestTriggerProviderService: @@ -75,7 +76,7 @@ class TestTriggerProviderService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + 
password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index f1e8c152f1..425611744b 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -12,6 +12,7 @@ from models.web import PinnedConversation from services.account_service import AccountService, TenantService from services.app_service import AppService from services.web_conversation_service import WebConversationService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWebConversationService: @@ -69,7 +70,7 @@ class TestWebConversationService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 9a1595d266..4fe65d5803 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -12,6 +12,7 @@ from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAcco from models.model import App, Site from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.webapp_auth_service import WebAppAuthService, WebAppAuthType +from tests.test_containers_integration_tests.helpers import 
generate_valid_password class TestWebAppAuthService: @@ -109,7 +110,7 @@ class TestWebAppAuthService: tuple: (account, tenant, password) - Created account, tenant and password """ fake = Faker() - password = fake.password(length=12) + password = generate_valid_password(fake) # Create account with password import uuid @@ -272,7 +273,7 @@ class TestWebAppAuthService: """ # Arrange: Create banned account fake = Faker() - password = fake.password(length=12) + password = generate_valid_password(fake) unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" account = Account( diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 8f345b9cea..f91e6efb10 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -13,6 +13,7 @@ from models.trigger import AppTrigger, WorkflowWebhookTrigger from models.workflow import Workflow from services.account_service import AccountService, TenantService from services.trigger.webhook_service import WebhookService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWebhookService: @@ -60,7 +61,7 @@ class TestWebhookService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index a3440b6b67..8ab8df2a5a 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ 
-15,6 +15,7 @@ from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService from services.workflow_app_service import WorkflowAppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWorkflowAppService: @@ -72,7 +73,7 @@ class TestWorkflowAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -120,7 +121,7 @@ class TestWorkflowAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 38ef3975b7..e080d6ef6b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -15,6 +15,7 @@ from models.workflow import WorkflowRun from services.account_service import AccountService, TenantService from services.app_service import AppService from services.workflow_run_service import WorkflowRunService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWorkflowRunService: @@ -72,7 +73,7 @@ class TestWorkflowRunService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = 
account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 0b3c1112bd..34906a4e54 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -13,6 +13,7 @@ from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService from services.app_service import AppService from services.tools.workflow_tools_manage_service import WorkflowToolManageService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWorkflowToolManageService: @@ -87,7 +88,7 @@ class TestWorkflowToolManageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 379986c191..3ce199c602 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -15,6 +15,7 @@ from faker import Faker from models.dataset import Dataset, Document, DocumentSegment from services.account_service import AccountService, TenantService from tasks.clean_notion_document_task import clean_notion_document_task +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestCleanNotionDocumentTask: @@ -76,7 +77,7 @@ class TestCleanNotionDocumentTask: 
email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -208,7 +209,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -252,7 +253,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -345,7 +346,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -431,7 +432,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -546,7 +547,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -642,7 +643,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + 
password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -724,7 +725,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -834,7 +835,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -951,7 +952,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -1054,7 +1055,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 207bdad751..4a62383590 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - 
task_dispatch_spy.delay.assert_called_once_with( - tenant_id=next_task["tenant_id"], - dataset_id=next_task["dataset_id"], - document_ids=next_task["document_ids"], - ) + # apply_async is used by implementation; assert it was called once with expected kwargs + assert task_dispatch_spy.apply_async.call_count == 1 + call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {}) + assert call_kwargs == { + "tenant_id": next_task["tenant_id"], + "dataset_id": next_task["dataset_id"], + "document_ids": next_task["document_ids"], + } set_waiting_spy.assert_called_once() delete_key_spy.assert_not_called() @@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - task_dispatch_spy.delay.assert_not_called() + task_dispatch_spy.apply_async.assert_not_called() delete_key_spy.assert_called_once() def test_validation_failure_sets_error_status_when_vector_space_at_limit( @@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - task_dispatch_spy.delay.assert_called_once() + task_dispatch_spy.apply_async.assert_called_once() def test_sessions_close_on_successful_indexing( self, @@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - assert task_dispatch_spy.delay.call_count == concurrency_limit + assert task_dispatch_spy.apply_async.call_count == concurrency_limit assert set_waiting_spy.call_count == concurrency_limit def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies): @@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - assert 
task_dispatch_spy.delay.call_count == 3 + assert task_dispatch_spy.apply_async.call_count == 3 for index, expected_task in enumerate(ordered_tasks): - assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"] + call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {}) + assert call_kwargs.get("document_ids") == expected_task["document_ids"] def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies): """Skip limit checks when billing feature is disabled.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 58c3ab5509..10c719fb6d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -15,6 +15,7 @@ from faker import Faker from models.dataset import Dataset, Document, DocumentSegment from services.account_service import AccountService, TenantService from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestDealDatasetVectorIndexTask: @@ -61,7 +62,7 @@ class TestDealDatasetVectorIndexTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 4be1180c73..5dc1f6bee0 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -762,11 +762,12 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify task function was called for each waiting task - assert mock_task_func.delay.call_count == 1 + assert mock_task_func.apply_async.call_count == 1 # Verify correct parameters for each call - calls = mock_task_func.delay.call_args_list - assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + calls = mock_task_func.apply_async.call_args_list + sent_kwargs = calls[0][1]["kwargs"] + assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} # Verify queue is empty after processing (tasks were pulled) remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added @@ -830,11 +831,15 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None # Verify waiting task was still processed despite core processing error - mock_task_func.delay.assert_called_once() + mock_task_func.apply_async.assert_called_once() # Verify correct parameters for the call - call = mock_task_func.delay.call_args - assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + call = mock_task_func.apply_async.call_args + assert call[1]["kwargs"] == { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": ["waiting-doc-1"], + } # Verify queue is empty after processing (task was pulled) remaining_tasks = queue.pull_tasks(count=10) @@ -896,9 +901,13 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify only tenant1's waiting task was processed - mock_task_func.delay.assert_called_once() - call = 
mock_task_func.delay.call_args - assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]} + mock_task_func.apply_async.assert_called_once() + call = mock_task_func.apply_async.call_args + assert call[1]["kwargs"] == { + "tenant_id": tenant1_id, + "dataset_id": dataset1_id, + "document_ids": ["tenant1-doc-1"], + } # Verify tenant1's queue is empty remaining_tasks1 = queue1.pull_tasks(count=10) diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index ef7191299a..f01fcc1742 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -1,6 +1,6 @@ import json import uuid -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -388,8 +388,10 @@ class TestRagPipelineRunTasks: # Set the task key to indicate there are waiting tasks (legacy behavior) redis_client.set(legacy_task_key, 1, ex=60 * 60) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the priority task with new code but legacy queue data rag_pipeline_run_task(file_id, tenant.id) @@ -398,13 +400,14 @@ class TestRagPipelineRunTasks: mock_file_service["delete_file"].assert_called_once_with(file_id) assert mock_pipeline_generator.call_count == 1 - # Verify waiting tasks were processed, pull 1 task a time by default - assert mock_delay.call_count == 1 + # Verify waiting tasks were processed via group, pull 1 task a time by default + assert 
mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0] - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0] + assert first_kwargs.get("tenant_id") == tenant.id # Verify that new code can process legacy queue entries # The new TenantIsolatedTaskQueue should be able to read from the legacy format @@ -446,8 +449,10 @@ class TestRagPipelineRunTasks: waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)] queue.push_tasks(waiting_file_ids) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the regular task rag_pipeline_run_task(file_id, tenant.id) @@ -456,13 +461,14 @@ class TestRagPipelineRunTasks: mock_file_service["delete_file"].assert_called_once_with(file_id) assert mock_pipeline_generator.call_count == 1 - # Verify waiting tasks were processed, pull 1 task a time by default - assert mock_delay.call_count == 1 + # Verify waiting tasks were processed via group.apply_async + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled 
job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] + assert first_kwargs.get("tenant_id") == tenant.id # Verify queue still has remaining tasks (only 1 was pulled) remaining_tasks = queue.pull_tasks(count=10) @@ -557,8 +563,10 @@ class TestRagPipelineRunTasks: waiting_file_id = str(uuid.uuid4()) queue.push_tasks([waiting_file_id]) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the regular task (should not raise exception) rag_pipeline_run_task(file_id, tenant.id) @@ -569,12 +577,13 @@ class TestRagPipelineRunTasks: assert mock_pipeline_generator.call_count == 1 # Verify waiting task was still processed despite core processing error - mock_delay.assert_called_once() + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert first_kwargs.get("tenant_id") == tenant.id # Verify queue is empty after processing (task was pulled) remaining_tasks = queue.pull_tasks(count=10) @@ -684,8 +693,10 @@ class TestRagPipelineRunTasks: queue1.push_tasks([waiting_file_id1]) queue2.push_tasks([waiting_file_id2]) - 
# Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the regular task for tenant1 only rag_pipeline_run_task(file_id1, tenant1.id) @@ -694,11 +705,12 @@ class TestRagPipelineRunTasks: assert mock_file_service["delete_file"].call_count == 1 assert mock_pipeline_generator.call_count == 1 - # Verify only tenant1's waiting task was processed - mock_delay.assert_called_once() - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 - assert call_kwargs.get("tenant_id") == tenant1.id + # Verify only tenant1's waiting task was processed (via group) + assert mock_group.return_value.apply_async.called + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 + assert first_kwargs.get("tenant_id") == tenant1.id # Verify tenant1's queue is empty remaining_tasks1 = queue1.pull_tasks(count=10) @@ -913,8 +925,10 @@ class TestRagPipelineRunTasks: waiting_file_id = str(uuid.uuid4()) queue.push_tasks([waiting_file_id]) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act & Assert: Execute the regular task (should raise Exception) with pytest.raises(Exception, match="File not found"): rag_pipeline_run_task(file_id, tenant.id) @@ -924,12 +938,13 @@ class TestRagPipelineRunTasks: 
mock_pipeline_generator.assert_not_called() # Verify waiting task was still processed despite file error - mock_delay.assert_called_once() + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert first_kwargs.get("tenant_id") == tenant.id # Verify queue is empty after processing (task was pulled) remaining_tasks = queue.pull_tasks(count=10) diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py index 9c1fd5e0ec..e3832fb2ef 100644 --- a/api/tests/test_containers_integration_tests/trigger/conftest.py +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -105,18 +105,26 @@ def app_model( class MockCeleryGroup: - """Mock for celery group() function that collects dispatched tasks.""" + """Mock for celery group() function that collects dispatched tasks. + + Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async + (e.g. producer) so production code can pass broker-related options without + breaking tests. 
+ """ def __init__(self) -> None: self.collected: list[dict[str, Any]] = [] self._applied = False + self.last_apply_async_kwargs: dict[str, Any] | None = None def __call__(self, items: Any) -> MockCeleryGroup: self.collected = list(items) return self - def apply_async(self) -> None: + def apply_async(self, **kwargs: Any) -> None: + # Accept arbitrary kwargs like producer to be compatible with Celery self._applied = True + self.last_apply_async_kwargs = kwargs @property def applied(self) -> bool: diff --git a/api/tests/unit_tests/commands/test_clean_expired_messages.py b/api/tests/unit_tests/commands/test_clean_expired_messages.py new file mode 100644 index 0000000000..60173f723d --- /dev/null +++ b/api/tests/unit_tests/commands/test_clean_expired_messages.py @@ -0,0 +1,181 @@ +import datetime +import re +from unittest.mock import MagicMock, patch + +import click +import pytest + +from commands import clean_expired_messages + + +def _mock_service() -> MagicMock: + service = MagicMock() + service.run.return_value = { + "batches": 1, + "total_messages": 10, + "filtered_messages": 5, + "total_deleted": 5, + } + return service + + +def test_absolute_mode_calls_from_time_range(): + policy = object() + service = _mock_service() + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 2, 1, 0, 0, 0) + + with ( + patch("commands.retention.create_message_clean_policy", return_value=policy), + patch("commands.retention.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range, + patch("commands.retention.MessagesCleanService.from_days") as mock_from_days, + ): + clean_expired_messages.callback( + batch_size=200, + graceful_period=21, + start_from=start_from, + end_before=end_before, + from_days_ago=None, + before_days=None, + dry_run=True, + ) + + mock_from_time_range.assert_called_once_with( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=200, + dry_run=True, + ) + 
mock_from_days.assert_not_called() + + +def test_relative_mode_before_days_only_calls_from_days(): + policy = object() + service = _mock_service() + + with ( + patch("commands.retention.create_message_clean_policy", return_value=policy), + patch("commands.retention.MessagesCleanService.from_days", return_value=service) as mock_from_days, + patch("commands.retention.MessagesCleanService.from_time_range") as mock_from_time_range, + ): + clean_expired_messages.callback( + batch_size=500, + graceful_period=14, + start_from=None, + end_before=None, + from_days_ago=None, + before_days=30, + dry_run=False, + ) + + mock_from_days.assert_called_once_with( + policy=policy, + days=30, + batch_size=500, + dry_run=False, + ) + mock_from_time_range.assert_not_called() + + +def test_relative_mode_with_from_days_ago_calls_from_time_range(): + policy = object() + service = _mock_service() + fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0) + + with ( + patch("commands.retention.create_message_clean_policy", return_value=policy), + patch("commands.retention.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range, + patch("commands.retention.MessagesCleanService.from_days") as mock_from_days, + patch("commands.retention.naive_utc_now", return_value=fixed_now), + ): + clean_expired_messages.callback( + batch_size=1000, + graceful_period=21, + start_from=None, + end_before=None, + from_days_ago=60, + before_days=30, + dry_run=False, + ) + + mock_from_time_range.assert_called_once_with( + policy=policy, + start_from=fixed_now - datetime.timedelta(days=60), + end_before=fixed_now - datetime.timedelta(days=30), + batch_size=1000, + dry_run=False, + ) + mock_from_days.assert_not_called() + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ( + { + "start_from": datetime.datetime(2024, 1, 1), + "end_before": datetime.datetime(2024, 2, 1), + "from_days_ago": None, + "before_days": 30, + }, + "mutually exclusive", + ), + ( + { + "start_from": 
datetime.datetime(2024, 1, 1), + "end_before": None, + "from_days_ago": None, + "before_days": None, + }, + "Both --start-from and --end-before are required", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": 10, + "before_days": None, + }, + "--from-days-ago must be used together with --before-days", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": None, + "before_days": -1, + }, + "--before-days must be >= 0", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": 30, + "before_days": 30, + }, + "--from-days-ago must be greater than --before-days", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": None, + "before_days": None, + }, + "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])", + ), + ], +) +def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str): + with pytest.raises(click.UsageError, match=re.escape(message)): + clean_expired_messages.callback( + batch_size=1000, + graceful_period=21, + start_from=kwargs["start_from"], + end_before=kwargs["end_before"], + from_days_ago=kwargs["from_days_ago"], + before_days=kwargs["before_days"], + dry_run=False, + ) diff --git a/api/tests/unit_tests/commands/test_upgrade_db.py b/api/tests/unit_tests/commands/test_upgrade_db.py index 80173f5d46..5aa0313429 100644 --- a/api/tests/unit_tests/commands/test_upgrade_db.py +++ b/api/tests/unit_tests/commands/test_upgrade_db.py @@ -4,6 +4,7 @@ import types from unittest.mock import MagicMock import commands +from commands import system as system_commands from libs.db_migration_lock import LockNotOwnedError, RedisError HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0 @@ -24,11 +25,11 @@ def _invoke_upgrade_db() -> int: def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys): - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234) lock = 
MagicMock() lock.acquire.return_value = False - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock exit_code = _invoke_upgrade_db() captured = capsys.readouterr() @@ -36,18 +37,18 @@ def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys): assert exit_code == 0 assert "Database migration skipped" in captured.out - commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False) + system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False) lock.acquire.assert_called_once_with(blocking=False) lock.release.assert_not_called() def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys): - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321) lock = MagicMock() lock.acquire.return_value = True lock.release.side_effect = LockNotOwnedError("simulated") - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock def _upgrade(): raise RuntimeError("boom") @@ -60,18 +61,18 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys): assert exit_code == 1 assert "Database migration failed: boom" in captured.out - commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False) + system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False) lock.acquire.assert_called_once_with(blocking=False) lock.release.assert_called_once() def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys): - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999) lock = MagicMock() lock.acquire.return_value = True lock.release.side_effect = LockNotOwnedError("simulated") 
- commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock _install_fake_flask_migrate(monkeypatch, lambda: None) @@ -81,7 +82,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy assert exit_code == 0 assert "Database migration successful!" in captured.out - commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False) + system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False) lock.acquire.assert_called_once_with(blocking=False) lock.release.assert_called_once() @@ -92,11 +93,11 @@ def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys): """ # Use a small TTL so the heartbeat interval triggers quickly. - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) lock = MagicMock() lock.acquire.return_value = True - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock renewed = threading.Event() @@ -120,11 +121,11 @@ def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys): def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys): # Use a small TTL so heartbeat runs during the upgrade call. 
- monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) lock = MagicMock() lock.acquire.return_value = True - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock attempted = threading.Event() diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index d2111ebac8..3f75fd2851 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -32,11 +32,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs") os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") os.environ.setdefault("STORAGE_TYPE", "opendal") -# Add the API directory to Python path to ensure proper imports -import sys - -sys.path.insert(0, PROJECT_DIR) - from core.db.session_factory import configure_session_factory, session_factory from extensions import ext_redis diff --git a/api/tests/unit_tests/controllers/common/test_errors.py b/api/tests/unit_tests/controllers/common/test_errors.py new file mode 100644 index 0000000000..25a9fe5b66 --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_errors.py @@ -0,0 +1,70 @@ +from controllers.common.errors import ( + BlockedFileExtensionError, + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + RemoteFileUploadError, + TooManyFilesError, + UnsupportedFileTypeError, +) + + +class TestFilenameNotExistsError: + def test_defaults(self): + error = FilenameNotExistsError() + + assert error.code == 400 + assert error.description == "The specified filename does not exist." + + +class TestRemoteFileUploadError: + def test_defaults(self): + error = RemoteFileUploadError() + + assert error.code == 400 + assert error.description == "Error uploading remote file." 
+ + +class TestFileTooLargeError: + def test_defaults(self): + error = FileTooLargeError() + + assert error.code == 413 + assert error.error_code == "file_too_large" + assert error.description == "File size exceeded. {message}" + + +class TestUnsupportedFileTypeError: + def test_defaults(self): + error = UnsupportedFileTypeError() + + assert error.code == 415 + assert error.error_code == "unsupported_file_type" + assert error.description == "File type not allowed." + + +class TestBlockedFileExtensionError: + def test_defaults(self): + error = BlockedFileExtensionError() + + assert error.code == 400 + assert error.error_code == "file_extension_blocked" + assert error.description == "The file extension is blocked for security reasons." + + +class TestTooManyFilesError: + def test_defaults(self): + error = TooManyFilesError() + + assert error.code == 400 + assert error.error_code == "too_many_files" + assert error.description == "Only one file is allowed." + + +class TestNoFileUploadedError: + def test_defaults(self): + error = NoFileUploadedError() + + assert error.code == 400 + assert error.error_code == "no_file_uploaded" + assert error.description == "Please upload your file." 
diff --git a/api/tests/unit_tests/controllers/common/test_file_response.py b/api/tests/unit_tests/controllers/common/test_file_response.py index 2487c362bd..b7500fb7f9 100644 --- a/api/tests/unit_tests/controllers/common/test_file_response.py +++ b/api/tests/unit_tests/controllers/common/test_file_response.py @@ -1,22 +1,95 @@ from flask import Response -from controllers.common.file_response import enforce_download_for_html, is_html_content +from controllers.common.file_response import ( + _normalize_mime_type, + enforce_download_for_html, + is_html_content, +) -class TestFileResponseHelpers: - def test_is_html_content_detects_mime_type(self): +class TestNormalizeMimeType: + def test_returns_empty_string_for_none(self): + assert _normalize_mime_type(None) == "" + + def test_returns_empty_string_for_empty_string(self): + assert _normalize_mime_type("") == "" + + def test_normalizes_mime_type(self): + assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html" + + +class TestIsHtmlContent: + def test_detects_html_via_mime_type(self): mime_type = "text/html; charset=UTF-8" - result = is_html_content(mime_type, filename="file.txt", extension="txt") + result = is_html_content( + mime_type=mime_type, + filename="file.txt", + extension="txt", + ) assert result is True - def test_is_html_content_detects_extension(self): - result = is_html_content("text/plain", filename="report.html", extension=None) + def test_detects_html_via_extension_argument(self): + result = is_html_content( + mime_type="text/plain", + filename=None, + extension="html", + ) assert result is True - def test_enforce_download_for_html_sets_headers(self): + def test_detects_html_via_filename_extension(self): + result = is_html_content( + mime_type="text/plain", + filename="report.html", + extension=None, + ) + + assert result is True + + def test_returns_false_when_no_html_detected_anywhere(self): + """ + Missing negative test: + - MIME type is not HTML + - filename has no HTML extension + - 
extension argument is not HTML + """ + result = is_html_content( + mime_type="application/json", + filename="data.json", + extension="json", + ) + + assert result is False + + def test_returns_false_when_all_inputs_are_none(self): + result = is_html_content( + mime_type=None, + filename=None, + extension=None, + ) + + assert result is False + + +class TestEnforceDownloadForHtml: + def test_sets_attachment_when_filename_missing(self): + response = Response("payload", mimetype="text/html") + + updated = enforce_download_for_html( + response, + mime_type="text/html", + filename=None, + extension="html", + ) + + assert updated is True + assert response.headers["Content-Disposition"] == "attachment" + assert response.headers["Content-Type"] == "application/octet-stream" + assert response.headers["X-Content-Type-Options"] == "nosniff" + + def test_sets_headers_when_filename_present(self): response = Response("payload", mimetype="text/html") updated = enforce_download_for_html( @@ -27,11 +100,12 @@ class TestFileResponseHelpers: ) assert updated is True - assert "attachment" in response.headers["Content-Disposition"] + assert response.headers["Content-Disposition"].startswith("attachment") + assert "unsafe.html" in response.headers["Content-Disposition"] assert response.headers["Content-Type"] == "application/octet-stream" assert response.headers["X-Content-Type-Options"] == "nosniff" - def test_enforce_download_for_html_no_change_for_non_html(self): + def test_does_not_modify_response_for_non_html_content(self): response = Response("payload", mimetype="text/plain") updated = enforce_download_for_html( diff --git a/api/tests/unit_tests/controllers/common/test_helpers.py b/api/tests/unit_tests/controllers/common/test_helpers.py new file mode 100644 index 0000000000..59c463177c --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_helpers.py @@ -0,0 +1,188 @@ +from uuid import UUID + +import httpx +import pytest + +from controllers.common import helpers +from 
controllers.common.helpers import FileInfo, guess_file_info_from_response + + +def make_response( + url="https://example.com/file.txt", + headers=None, + content=None, +): + return httpx.Response( + 200, + request=httpx.Request("GET", url), + headers=headers or {}, + content=content or b"", + ) + + +class TestGuessFileInfoFromResponse: + def test_filename_from_url(self): + response = make_response( + url="https://example.com/test.pdf", + content=b"Hello World", + ) + + info = guess_file_info_from_response(response) + + assert info.filename == "test.pdf" + assert info.extension == ".pdf" + assert info.mimetype == "application/pdf" + + def test_filename_from_content_disposition(self): + headers = { + "Content-Disposition": "attachment; filename=myfile.csv", + "Content-Type": "text/csv", + } + response = make_response( + url="https://example.com/", + headers=headers, + content=b"Hello World", + ) + + info = guess_file_info_from_response(response) + + assert info.filename == "myfile.csv" + assert info.extension == ".csv" + assert info.mimetype == "text/csv" + + @pytest.mark.parametrize( + ("magic_available", "expected_ext"), + [ + (True, "txt"), + (False, "bin"), + ], + ) + def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext): + if magic_available: + if helpers.magic is None: + pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant") + else: + monkeypatch.setattr(helpers, "magic", None) + + response = make_response( + url="https://example.com/", + content=b"Hello World", + ) + + info = guess_file_info_from_response(response) + + name, ext = info.filename.split(".") + UUID(name) + assert ext == expected_ext + + def test_mimetype_from_header_when_unknown(self): + headers = {"Content-Type": "application/json"} + response = make_response( + url="https://example.com/file.unknown", + headers=headers, + content=b'{"a": 1}', + ) + + info = guess_file_info_from_response(response) + + assert info.mimetype 
== "application/json" + + def test_extension_added_when_missing(self): + headers = {"Content-Type": "image/png"} + response = make_response( + url="https://example.com/image", + headers=headers, + content=b"fakepngdata", + ) + + info = guess_file_info_from_response(response) + + assert info.extension == ".png" + assert info.filename.endswith(".png") + + def test_content_length_used_as_size(self): + headers = { + "Content-Length": "1234", + "Content-Type": "text/plain", + } + response = make_response( + url="https://example.com/a.txt", + headers=headers, + content=b"a" * 1234, + ) + + info = guess_file_info_from_response(response) + + assert info.size == 1234 + + def test_size_minus_one_when_header_missing(self): + response = make_response(url="https://example.com/a.txt") + + info = guess_file_info_from_response(response) + + assert info.size == -1 + + def test_fallback_to_bin_extension(self): + headers = {"Content-Type": "application/octet-stream"} + response = make_response( + url="https://example.com/download", + headers=headers, + content=b"\x00\x01\x02\x03", + ) + + info = guess_file_info_from_response(response) + + assert info.extension == ".bin" + assert info.filename.endswith(".bin") + + def test_return_type(self): + response = make_response() + + info = guess_file_info_from_response(response) + + assert isinstance(info, FileInfo) + + +class TestMagicImportWarnings: + @pytest.mark.parametrize( + ("platform_name", "expected_message"), + [ + ("Windows", "pip install python-magic-bin"), + ("Darwin", "brew install libmagic"), + ("Linux", "sudo apt-get install libmagic1"), + ("Other", "install `libmagic`"), + ], + ) + def test_magic_import_warning_per_platform( + self, + monkeypatch, + platform_name, + expected_message, + ): + import builtins + import importlib + + # Force ImportError when "magic" is imported + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "magic": + raise ImportError("No module named magic") + return 
real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + monkeypatch.setattr("platform.system", lambda: platform_name) + + # Remove helpers so it imports fresh + import sys + + original_helpers = sys.modules.get(helpers.__name__) + sys.modules.pop(helpers.__name__, None) + + try: + with pytest.warns(UserWarning, match="To use python-magic") as warning: + imported_helpers = importlib.import_module(helpers.__name__) + assert expected_message in str(warning[0].message) + finally: + if original_helpers is not None: + sys.modules[helpers.__name__] = original_helpers diff --git a/api/tests/unit_tests/controllers/common/test_schema.py b/api/tests/unit_tests/controllers/common/test_schema.py new file mode 100644 index 0000000000..56c8160f02 --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_schema.py @@ -0,0 +1,189 @@ +import sys +from enum import StrEnum +from unittest.mock import MagicMock, patch + +import pytest +from flask_restx import Namespace +from pydantic import BaseModel + + +class UserModel(BaseModel): + id: int + name: str + + +class ProductModel(BaseModel): + id: int + price: float + + +@pytest.fixture(autouse=True) +def mock_console_ns(): + """Mock the console_ns to avoid circular imports during test collection.""" + mock_ns = MagicMock(spec=Namespace) + mock_ns.models = {} + + # Inject mock before importing schema module + with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}): + yield mock_ns + + +def test_default_ref_template_value(): + from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0 + + assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}" + + +def test_register_schema_model_calls_namespace_schema_model(): + from controllers.common.schema import register_schema_model + + namespace = MagicMock(spec=Namespace) + + register_schema_model(namespace, UserModel) + + namespace.schema_model.assert_called_once() + + model_name, schema = 
namespace.schema_model.call_args.args + + assert model_name == "UserModel" + assert isinstance(schema, dict) + assert "properties" in schema + + +def test_register_schema_model_passes_schema_from_pydantic(): + from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model + + namespace = MagicMock(spec=Namespace) + + register_schema_model(namespace, UserModel) + + schema = namespace.schema_model.call_args.args[1] + + expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) + + assert schema == expected_schema + + +def test_register_schema_models_registers_multiple_models(): + from controllers.common.schema import register_schema_models + + namespace = MagicMock(spec=Namespace) + + register_schema_models(namespace, UserModel, ProductModel) + + assert namespace.schema_model.call_count == 2 + + called_names = [call.args[0] for call in namespace.schema_model.call_args_list] + assert called_names == ["UserModel", "ProductModel"] + + +def test_register_schema_models_calls_register_schema_model(monkeypatch): + from controllers.common.schema import register_schema_models + + namespace = MagicMock(spec=Namespace) + + calls = [] + + def fake_register(ns, model): + calls.append((ns, model)) + + monkeypatch.setattr( + "controllers.common.schema.register_schema_model", + fake_register, + ) + + register_schema_models(namespace, UserModel, ProductModel) + + assert calls == [ + (namespace, UserModel), + (namespace, ProductModel), + ] + + +class StatusEnum(StrEnum): + ACTIVE = "active" + INACTIVE = "inactive" + + +class PriorityEnum(StrEnum): + HIGH = "high" + LOW = "low" + + +def test_get_or_create_model_returns_existing_model(mock_console_ns): + from controllers.common.schema import get_or_create_model + + existing_model = MagicMock() + mock_console_ns.models = {"TestModel": existing_model} + + result = get_or_create_model("TestModel", {"key": "value"}) + + assert result == existing_model + 
mock_console_ns.model.assert_not_called() + + +def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns): + from controllers.common.schema import get_or_create_model + + mock_console_ns.models = {} + new_model = MagicMock() + mock_console_ns.model.return_value = new_model + field_def = {"name": {"type": "string"}} + + result = get_or_create_model("NewModel", field_def) + + assert result == new_model + mock_console_ns.model.assert_called_once_with("NewModel", field_def) + + +def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns): + from controllers.common.schema import get_or_create_model + + existing_model = MagicMock() + mock_console_ns.models = {"ExistingModel": existing_model} + + result = get_or_create_model("ExistingModel", {"key": "value"}) + + assert result == existing_model + mock_console_ns.model.assert_not_called() + + +def test_register_enum_models_registers_single_enum(): + from controllers.common.schema import register_enum_models + + namespace = MagicMock(spec=Namespace) + + register_enum_models(namespace, StatusEnum) + + namespace.schema_model.assert_called_once() + + model_name, schema = namespace.schema_model.call_args.args + + assert model_name == "StatusEnum" + assert isinstance(schema, dict) + + +def test_register_enum_models_registers_multiple_enums(): + from controllers.common.schema import register_enum_models + + namespace = MagicMock(spec=Namespace) + + register_enum_models(namespace, StatusEnum, PriorityEnum) + + assert namespace.schema_model.call_count == 2 + + called_names = [call.args[0] for call in namespace.schema_model.call_args_list] + assert called_names == ["StatusEnum", "PriorityEnum"] + + +def test_register_enum_models_uses_correct_ref_template(): + from controllers.common.schema import register_enum_models + + namespace = MagicMock(spec=Namespace) + + register_enum_models(namespace, StatusEnum) + + schema = namespace.schema_model.call_args.args[1] + + # Verify the schema contains enum 
values + assert "enum" in schema or "anyOf" in schema diff --git a/api/tests/unit_tests/controllers/console/app/__init__.py b/api/tests/unit_tests/controllers/console/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_api.py b/api/tests/unit_tests/controllers/console/app/test_annotation_api.py new file mode 100644 index 0000000000..fecbd7f7b0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_api.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from controllers.console.app import annotation as annotation_module + + +def test_annotation_reply_payload_valid(): + """Test AnnotationReplyPayload with valid data.""" + payload = annotation_module.AnnotationReplyPayload( + score_threshold=0.5, + embedding_provider_name="openai", + embedding_model_name="text-embedding-3-small", + ) + assert payload.score_threshold == 0.5 + assert payload.embedding_provider_name == "openai" + assert payload.embedding_model_name == "text-embedding-3-small" + + +def test_annotation_setting_update_payload_valid(): + """Test AnnotationSettingUpdatePayload with valid data.""" + payload = annotation_module.AnnotationSettingUpdatePayload( + score_threshold=0.75, + ) + assert payload.score_threshold == 0.75 + + +def test_annotation_list_query_defaults(): + """Test AnnotationListQuery with default parameters.""" + query = annotation_module.AnnotationListQuery() + assert query.page == 1 + assert query.limit == 20 + assert query.keyword == "" + + +def test_annotation_list_query_custom_page(): + """Test AnnotationListQuery with custom page.""" + query = annotation_module.AnnotationListQuery(page=3, limit=50) + assert query.page == 3 + assert query.limit == 50 + + +def test_annotation_list_query_with_keyword(): + """Test AnnotationListQuery with keyword.""" + query = annotation_module.AnnotationListQuery(keyword="test") + assert query.keyword == "test" + + +def 
test_create_annotation_payload_with_message_id(): + """Test CreateAnnotationPayload with message ID.""" + payload = annotation_module.CreateAnnotationPayload( + message_id="550e8400-e29b-41d4-a716-446655440000", + question="What is AI?", + ) + assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000" + assert payload.question == "What is AI?" + + +def test_create_annotation_payload_with_text(): + """Test CreateAnnotationPayload with text content.""" + payload = annotation_module.CreateAnnotationPayload( + question="What is ML?", + answer="Machine learning is...", + ) + assert payload.question == "What is ML?" + assert payload.answer == "Machine learning is..." + + +def test_update_annotation_payload(): + """Test UpdateAnnotationPayload.""" + payload = annotation_module.UpdateAnnotationPayload( + question="Updated question", + answer="Updated answer", + ) + assert payload.question == "Updated question" + assert payload.answer == "Updated answer" + + +def test_annotation_reply_status_query_enable(): + """Test AnnotationReplyStatusQuery with enable action.""" + query = annotation_module.AnnotationReplyStatusQuery(action="enable") + assert query.action == "enable" + + +def test_annotation_reply_status_query_disable(): + """Test AnnotationReplyStatusQuery with disable action.""" + query = annotation_module.AnnotationReplyStatusQuery(action="disable") + assert query.action == "disable" + + +def test_annotation_file_payload_valid(): + """Test AnnotationFilePayload with valid message ID.""" + payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000") + assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000" diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 06a7b98baf..9f1ff9b40f 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ 
b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -13,6 +13,9 @@ from pandas.errors import ParserError from werkzeug.datastructures import FileStorage from configs import dify_config +from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit +from services.annotation_service import AppAnnotationService +from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task class TestAnnotationImportRateLimiting: @@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting: def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account): """Test that per-minute rate limit is enforced.""" - from controllers.console.wraps import annotation_import_rate_limit - # Simulate exceeding per-minute limit mock_redis.zcard.side_effect = [ dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check @@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting: def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account): """Test that per-hour rate limit is enforced.""" - from controllers.console.wraps import annotation_import_rate_limit # Simulate exceeding per-hour limit mock_redis.zcard.side_effect = [ @@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting: def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account): """Test that requests within limits are allowed.""" - from controllers.console.wraps import annotation_import_rate_limit # Simulate being under both limits mock_redis.zcard.return_value = 2 @@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl: def test_concurrency_limit_enforced(self, mock_redis, mock_current_account): """Test that concurrent task limit is enforced.""" - from controllers.console.wraps import annotation_import_concurrency_limit # Simulate max concurrent tasks already running mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT @@ -127,7 +125,6 @@ class 
TestAnnotationImportConcurrencyControl: def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account): """Test that requests within concurrency limits are allowed.""" - from controllers.console.wraps import annotation_import_concurrency_limit # Simulate being under concurrent task limit mock_redis.zcard.return_value = 1 @@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl: def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account): """Test that old/stale job entries are removed.""" - from controllers.console.wraps import annotation_import_concurrency_limit mock_redis.zcard.return_value = 0 @@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation: def test_max_records_limit_enforced(self, mock_app, mock_db_session): """Test that files with too many records are rejected.""" - from services.annotation_service import AppAnnotationService # Create CSV with too many records max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS @@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation: def test_min_records_limit_enforced(self, mock_app, mock_db_session): """Test that files with too few valid records are rejected.""" - from services.annotation_service import AppAnnotationService # Create CSV with only header (no data rows) csv_content = "question,answer\n" @@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation: def test_invalid_csv_format_handled(self, mock_app, mock_db_session): """Test that invalid CSV format is handled gracefully.""" - from services.annotation_service import AppAnnotationService # Any content is fine once we force ParserError csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' @@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation: def test_valid_import_succeeds(self, mock_app, mock_db_session): """Test that valid import request succeeds.""" - from services.annotation_service import AppAnnotationService # Create valid CSV csv_content = "question,answer\nWhat 
is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n" @@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation: class TestAnnotationImportTaskOptimization: """Test optimizations in batch import task.""" - def test_task_has_timeout_configured(self): - """Test that task has proper timeout configuration.""" - from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task - - # Verify task configuration - assert hasattr(batch_import_annotations_task, "time_limit") - assert hasattr(batch_import_annotations_task, "soft_time_limit") - - # Check timeout values are reasonable - # Hard limit should be 6 minutes (360s) - # Soft limit should be 5 minutes (300s) - # Note: actual values depend on Celery configuration + def test_task_is_registered_with_queue(self): + """Test that task is registered with the correct queue.""" + assert hasattr(batch_import_annotations_task, "apply_async") + assert hasattr(batch_import_annotations_task, "delay") class TestConfigurationValues: diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py new file mode 100644 index 0000000000..074bbfab78 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -0,0 +1,585 @@ +""" +Additional tests to improve coverage for low-coverage modules in controllers/console/app. +Target: increase coverage for files with <75% coverage. 
+""" + +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.app import ( + annotation as annotation_module, +) +from controllers.console.app import ( + completion as completion_module, +) +from controllers.console.app import ( + message as message_module, +) +from controllers.console.app import ( + ops_trace as ops_trace_module, +) +from controllers.console.app import ( + site as site_module, +) +from controllers.console.app import ( + statistic as statistic_module, +) +from controllers.console.app import ( + workflow_app_log as workflow_app_log_module, +) +from controllers.console.app import ( + workflow_draft_variable as workflow_draft_variable_module, +) +from controllers.console.app import ( + workflow_statistic as workflow_statistic_module, +) +from controllers.console.app import ( + workflow_trigger as workflow_trigger_module, +) +from controllers.console.app import ( + wraps as wraps_module, +) +from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload +from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload +from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery +from controllers.console.app.site import AppSiteUpdatePayload +from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload +from controllers.console.app.workflow_app_log import WorkflowAppLogQuery +from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload +from controllers.console.app.workflow_statistic import WorkflowStatisticQuery +from controllers.console.app.workflow_trigger import Parser, ParserEnable + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self 
is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _ConnContext: + def __init__(self, rows): + self._rows = rows + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _query, _args): + return self._rows + + +# ========== Completion Tests ========== +class TestCompletionEndpoints: + """Tests for completion API endpoints.""" + + def test_completion_create_payload(self): + """Test completion creation payload.""" + payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={}) + assert payload.inputs == {"prompt": "test"} + + def test_chat_message_payload_uuid_validation(self): + payload = ChatMessagePayload( + inputs={}, + model_config={}, + query="hi", + conversation_id=str(uuid.uuid4()), + parent_message_id=str(uuid.uuid4()), + ) + assert payload.query == "hi" + + def test_completion_api_success(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: {"text": "ok"}, + ) + monkeypatch.setattr( + completion_module.helper, + "compact_generate_response", + lambda response: {"result": response}, + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + resp = method(app_model=MagicMock(id="app-1")) + + assert resp == {"result": {"text": "ok"}} + + def test_completion_api_conversation_not_exists(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + 
monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: (_ for _ in ()).throw( + completion_module.services.errors.conversation.ConversationNotExistsError() + ), + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + with pytest.raises(NotFound): + method(app_model=MagicMock(id="app-1")) + + def test_completion_api_provider_not_initialized(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")), + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + with pytest.raises(completion_module.ProviderNotInitializeError): + method(app_model=MagicMock(id="app-1")) + + def test_completion_api_quota_exceeded(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()), + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + with pytest.raises(completion_module.ProviderQuotaExceededError): + method(app_model=MagicMock(id="app-1")) + + +# ========== OpsTrace Tests ========== +class 
TestOpsTraceEndpoints: + """Tests for ops_trace endpoint.""" + + def test_ops_trace_query_basic(self): + """Test ops_trace query.""" + query = TraceProviderQuery(tracing_provider="langfuse") + assert query.tracing_provider == "langfuse" + + def test_ops_trace_config_payload(self): + payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"}) + assert payload.tracing_config["api_key"] == "k" + + def test_trace_app_config_get_empty(self, app, monkeypatch): + api = ops_trace_module.TraceAppConfigApi() + method = _unwrap(api.get) + + monkeypatch.setattr( + ops_trace_module.OpsService, + "get_tracing_app_config", + lambda **_kwargs: None, + ) + + with app.test_request_context("/?tracing_provider=langfuse"): + result = method(app_id="app-1") + + assert result == {"has_not_configured": True} + + def test_trace_app_config_post_invalid(self, app, monkeypatch): + api = ops_trace_module.TraceAppConfigApi() + method = _unwrap(api.post) + + monkeypatch.setattr( + ops_trace_module.OpsService, + "create_tracing_app_config", + lambda **_kwargs: {"error": True}, + ) + + with app.test_request_context( + "/", + json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}}, + ): + with pytest.raises(BadRequest): + method(app_id="app-1") + + def test_trace_app_config_delete_not_found(self, app, monkeypatch): + api = ops_trace_module.TraceAppConfigApi() + method = _unwrap(api.delete) + + monkeypatch.setattr( + ops_trace_module.OpsService, + "delete_tracing_app_config", + lambda **_kwargs: False, + ) + + with app.test_request_context("/?tracing_provider=langfuse"): + with pytest.raises(BadRequest): + method(app_id="app-1") + + +# ========== Site Tests ========== +class TestSiteEndpoints: + """Tests for site endpoint.""" + + def test_site_response_structure(self): + """Test site response structure.""" + payload = AppSiteUpdatePayload(title="My Site", description="Test site") + assert payload.title == "My Site" + + def 
test_site_default_language_validation(self): + payload = AppSiteUpdatePayload(default_language="en-US") + assert payload.default_language == "en-US" + + def test_app_site_update_post(self, app, monkeypatch): + api = site_module.AppSite() + method = _unwrap(api.post) + + site = MagicMock() + query = MagicMock() + query.where.return_value.first.return_value = site + monkeypatch.setattr( + site_module.db, + "session", + MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + ) + monkeypatch.setattr( + site_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") + + with app.test_request_context("/", json={"title": "My Site"}): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result is site + + def test_app_site_access_token_reset(self, app, monkeypatch): + api = site_module.AppSiteAccessTokenReset() + method = _unwrap(api.post) + + site = MagicMock() + query = MagicMock() + query.where.return_value.first.return_value = site + monkeypatch.setattr( + site_module.db, + "session", + MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + ) + monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") + monkeypatch.setattr( + site_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") + + with app.test_request_context("/"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result is site + + +# ========== Workflow Tests ========== +class TestWorkflowEndpoints: + """Tests for workflow endpoints.""" + + def test_workflow_copy_payload(self): + """Test workflow copy payload.""" + payload = SyncDraftWorkflowPayload(graph={}, features={}) + assert payload.graph == {} + + def test_workflow_mode_query(self): + """Test workflow mode query.""" + payload = 
AdvancedChatWorkflowRunPayload(inputs={}, query="hi") + assert payload.query == "hi" + + +# ========== Workflow App Log Tests ========== +class TestWorkflowAppLogEndpoints: + """Tests for workflow app log endpoints.""" + + def test_workflow_app_log_query(self): + """Test workflow app log query.""" + query = WorkflowAppLogQuery(keyword="test", page=1, limit=20) + assert query.keyword == "test" + + def test_workflow_app_log_query_detail_bool(self): + query = WorkflowAppLogQuery(detail="true") + assert query.detail is True + + def test_workflow_app_log_api_get(self, app, monkeypatch): + api = workflow_app_log_module.WorkflowAppLogApi() + method = _unwrap(api.get) + + monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock())) + + class DummySession: + def __enter__(self): + return "session" + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession()) + + def fake_get_paginate(self, **_kwargs): + return {"items": [], "total": 0} + + monkeypatch.setattr( + workflow_app_log_module.WorkflowAppService, + "get_paginate_workflow_app_logs", + fake_get_paginate, + ) + + with app.test_request_context("/?page=1&limit=20"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result == {"items": [], "total": 0} + + +# ========== Workflow Draft Variable Tests ========== +class TestWorkflowDraftVariableEndpoints: + """Tests for workflow draft variable endpoints.""" + + def test_workflow_variable_creation(self): + """Test workflow variable creation.""" + payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test") + assert payload.name == "var1" + + def test_workflow_variable_collection_get(self, app, monkeypatch): + api = workflow_draft_variable_module.WorkflowVariableCollectionApi() + method = _unwrap(api.get) + + monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock())) + + class DummySession: + 
def __enter__(self): + return "session" + + def __exit__(self, exc_type, exc, tb): + return False + + class DummyDraftService: + def __init__(self, session): + self.session = session + + def list_variables_without_values(self, **_kwargs): + return {"items": [], "total": 0} + + monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession()) + + class DummyWorkflowService: + def is_workflow_exist(self, *args, **kwargs): + return True + + monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService) + monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService) + + with app.test_request_context("/?page=1&limit=20"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result == {"items": [], "total": 0} + + +# ========== Workflow Statistic Tests ========== +class TestWorkflowStatisticEndpoints: + """Tests for workflow statistic endpoints.""" + + def test_workflow_statistic_time_range(self): + """Test workflow statistic time range query.""" + query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31") + assert query.start == "2024-01-01" + + def test_workflow_statistic_blank_to_none(self): + query = WorkflowStatisticQuery(start="", end="") + assert query.start is None + assert query.end is None + + def test_workflow_daily_runs_statistic(self, app, monkeypatch): + monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr( + workflow_statistic_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]), + ) + monkeypatch.setattr( + workflow_statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + workflow_statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: (None, None), + ) + 
+ api = workflow_statistic_module.WorkflowDailyRunsStatistic() + method = _unwrap(api.get) + + with app.test_request_context("/"): + response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-01"}]} + + def test_workflow_daily_terminals_statistic(self, app, monkeypatch): + monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr( + workflow_statistic_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: SimpleNamespace( + get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}] + ), + ) + monkeypatch.setattr( + workflow_statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + workflow_statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: (None, None), + ) + + api = workflow_statistic_module.WorkflowDailyTerminalsStatistic() + method = _unwrap(api.get) + + with app.test_request_context("/"): + response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-02"}]} + + +# ========== Workflow Trigger Tests ========== +class TestWorkflowTriggerEndpoints: + """Tests for workflow trigger endpoints.""" + + def test_webhook_trigger_payload(self): + """Test webhook trigger payload.""" + payload = Parser(node_id="node-1") + assert payload.node_id == "node-1" + + enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True) + assert enable_payload.enable_trigger is True + + def test_webhook_trigger_api_get(self, app, monkeypatch): + api = workflow_trigger_module.WebhookTriggerApi() + method = _unwrap(api.get) + + monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock())) + + trigger = MagicMock() + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = 
trigger + + class DummySession: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession()) + + with app.test_request_context("/?node_id=node-1"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result is trigger + + +# ========== Wraps Tests ========== +class TestWrapsEndpoints: + """Tests for wraps utility functions.""" + + def test_get_app_model_context(self): + """Test get_app_model wrapper context.""" + # These are decorator functions, so we test their availability + assert hasattr(wraps_module, "get_app_model") + + +# ========== MCP Server Tests ========== +class TestMCPServerEndpoints: + """Tests for MCP server endpoints.""" + + def test_mcp_server_connection(self): + """Test MCP server connection.""" + payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"}) + assert payload.parameters["url"] == "http://localhost:3000" + + def test_mcp_server_update_payload(self): + payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active") + assert payload.status == "active" + + +# ========== Error Handling Tests ========== +class TestErrorHandling: + """Tests for error handling in various endpoints.""" + + def test_annotation_list_query_validation(self): + """Test annotation list query validation.""" + with pytest.raises(ValueError): + annotation_module.AnnotationListQuery(page=0) + + +# ========== Integration-like Tests ========== +class TestPayloadIntegration: + """Integration tests for payload handling.""" + + def test_multiple_payload_types(self): + """Test handling of multiple payload types.""" + payloads = [ + annotation_module.AnnotationReplyPayload( + score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small" + ), + message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"), + 
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"), + ] + assert len(payloads) == 3 + assert all(p is not None for p in payloads) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 0000000000..91f58460ac --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +class _SessionContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return False + + +def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None: + monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session)) + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = 
app_import_module.AppImportApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + +def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + +def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + _install_features(monkeypatch, enabled=True) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + 
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportConfirmApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.commit.assert_called_once() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + +def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportCheckDependenciesApi() + method = _unwrap(api.get) + + session = MagicMock() + _install_session(monkeypatch, session) + monkeypatch.setattr( + app_import_module.AppDslService, + "check_dependencies", + lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}), + ) + + with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"): + response, status = method(app_model=SimpleNamespace(id="app-1")) + + assert status == 200 + assert response["leaked_dependencies"] == [] diff --git 
a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py new file mode 100644 index 0000000000..021e9a0784 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import InternalServerError + +from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.audio_service import AudioService +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + ProviderNotSupportTextToSpeechLanageServiceError, + UnsupportedAudioTypeServiceError, +) + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def _file_data(): + return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav") + + +def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) + api = ChatMessageAudioApi() + handler = _unwrap(api.post) 
+ app_model = SimpleNamespace(id="a1") + + with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): + response = handler(app_model=app_model) + + assert response == {"text": "ok"} + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (AppModelConfigBrokenError(), AppUnavailableError), + (NoAudioUploadedServiceError(), NoAudioUploadedError), + (AudioTooLargeServiceError("too big"), AudioTooLargeError), + (UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError), + (ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError), + (ProviderTokenNotInitError("token"), ProviderNotInitializeError), + (QuotaExceededError(), ProviderQuotaExceededError), + (ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError), + (InvokeError("invoke"), CompletionRequestError), + ], +) +def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) + api = ChatMessageAudioApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): + with pytest.raises(expected): + handler(app_model=app_model) + + +def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) + api = ChatMessageAudioApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): + with pytest.raises(InternalServerError): + handler(app_model=app_model) + + +def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, 
"transcript_tts", lambda **_kwargs: {"audio": "ok"}) + + api = ChatMessageTextApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context( + "/console/api/apps/app/text-to-audio", + method="POST", + json={"text": "hello", "voice": "v"}, + ): + response = handler(app_model=app_model) + + assert response == {"audio": "ok"} + + +def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())) + + api = ChatMessageTextApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context( + "/console/api/apps/app/text-to-audio", + method="POST", + json={"text": "hello"}, + ): + with pytest.raises(ProviderQuotaExceededError): + handler(app_model=app_model) + + +def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) + + api = TextModesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(tenant_id="t1") + + with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"): + response = handler(app_model=app_model) + + assert response == ["voice-1"] + + +def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + AudioService, + "transcript_tts_voices", + lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()), + ) + + api = TextModesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(tenant_id="t1") + + with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"): + with pytest.raises(AppUnavailableError): + handler(app_model=app_model) + + +def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = 
ChatMessageAudioApi() + method = _unwrap(api.post) + + response_payload = {"text": "hello"} + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + response = method(app_model=app_model) + + assert response == response_payload + + +def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr( + AudioService, + "transcript_asr", + lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")), + ) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + with pytest.raises(AudioTooLargeError): + method(app_model=app_model) + + +def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello"}, + ): + response = method(app_model=app_model) + + assert response == {"audio": "ok"} + + +def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices", + 
method="GET", + query_string={"language": "en-US"}, + ): + response = method(app_model=app_model) + + assert response == ["voice-1"] + + +def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"}) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + # Should not raise, AudioService is mocked + response = method(app_model=app_model) + assert response == {"text": "test"} + + +def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello", "language": "en-US"}, + ): + response = method(app_model=app_model) + assert response == {"audio": "test"} + + +def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr( + AudioService, + "transcript_tts_voices", + lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}], + ) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices?language=en-US", + method="GET", + ): + response = method(app_model=app_model) + assert isinstance(response, list) diff --git a/api/tests/unit_tests/controllers/console/app/test_audio_api.py b/api/tests/unit_tests/controllers/console/app/test_audio_api.py new file mode 100644 index 0000000000..8b71837c29 
--- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_audio_api.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest + +from controllers.console.app import audio as audio_module +from controllers.console.app.error import AudioTooLargeError +from services.errors.audio import AudioTooLargeServiceError + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageAudioApi() + method = _unwrap(api.post) + + response_payload = {"text": "hello"} + monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + response = method(app_model=app_model) + + assert response == response_payload + + +def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr( + audio_module.AudioService, + "transcript_asr", + lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")), + ) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + with pytest.raises(AudioTooLargeError): + method(app_model=app_model) + + +def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = 
audio_module.ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello"}, + ): + response = method(app_model=app_model) + + assert response == {"audio": "ok"} + + +def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices", + method="GET", + query_string={"language": "en-US"}, + ): + response = method(app_model=app_model) + + assert response == ["voice-1"] + + +def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"}) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + # Should not raise, AudioService is mocked + response = method(app_model=app_model) + assert response == {"text": "test"} + + +def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + 
method="POST", + json={"text": "hello", "language": "en-US"}, + ): + response = method(app_model=app_model) + assert response == {"audio": "test"} + + +def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr( + audio_module.AudioService, + "transcript_tts_voices", + lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}], + ) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices?language=en-US", + method="GET", + ): + response = method(app_model=app_model) + assert isinstance(response, list) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py new file mode 100644 index 0000000000..5db8e5c332 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.app import conversation as conversation_module +from models.model import AppMode +from services.errors.conversation import ConversationNotExistsError + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def _make_account(): + return SimpleNamespace(timezone="UTC", id="u1") + + +def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.CompletionConversationApi() + method = _unwrap(api.get) + + account = _make_account() + monkeypatch.setattr(conversation_module, "current_account_with_tenant", 
lambda: (account, "t1")) + monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) + + paginate_result = MagicMock() + monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) + + with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response is paginate_result + + +def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.CompletionConversationApi() + method = _unwrap(api.get) + + account = _make_account() + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) + monkeypatch.setattr( + conversation_module, + "parse_time_range", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")), + ) + + with app.test_request_context( + "/console/api/apps/app-1/completion-conversations", + method="GET", + query_string={"start": "bad"}, + ): + with pytest.raises(BadRequest): + method(app_model=SimpleNamespace(id="app-1")) + + +def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.ChatConversationApi() + method = _unwrap(api.get) + + account = _make_account() + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) + monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) + + paginate_result = MagicMock() + monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) + + with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) + + assert response is paginate_result + + +def test_get_conversation_updates_read_at(monkeypatch: 
pytest.MonkeyPatch) -> None: + conversation = SimpleNamespace(id="c1", app_id="app-1") + + query = MagicMock() + query.where.return_value = query + query.first.return_value = conversation + + session = MagicMock() + session.query.return_value = query + + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) + monkeypatch.setattr(conversation_module.db, "session", session) + + result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1") + + assert result is conversation + session.execute.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once_with(conversation) + + +def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + + session = MagicMock() + session.query.return_value = query + + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) + monkeypatch.setattr(conversation_module.db, "session", session) + + with pytest.raises(NotFound): + conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing") + + +def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.CompletionConversationDetailApi() + method = _unwrap(api.delete) + + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) + monkeypatch.setattr( + conversation_module.ConversationService, + "delete", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()), + ) + + with pytest.raises(NotFound): + method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py new file mode 100644 index 0000000000..f83bc18da3 --- 
/dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from controllers.console.app import generator as generator_module +from controllers.console.app.error import ProviderNotInitializeError +from core.errors.error import ProviderTokenNotInitError + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def _model_config_payload(): + return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}} + + +def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow): + class _Service: + def get_draft_workflow(self, app_model): + return workflow + + monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service()) + + +def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.RuleGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []}) + + with app.test_request_context( + "/console/api/rule-generate", + method="POST", + json={"instruction": "do it", "model_config": _model_config_payload()}, + ): + response = method() + + assert response == {"rules": []} + + +def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.RuleCodeGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + def _raise(*_args, **_kwargs): + raise ProviderTokenNotInitError("missing token") + + monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", 
_raise) + + with app.test_request_context( + "/console/api/rule-code-generate", + method="POST", + json={"instruction": "do it", "model_config": _model_config_payload()}, + ): + with pytest.raises(ProviderNotInitializeError): + method() + + +def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "app app-1 not found" + + +def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + app_model = SimpleNamespace(id="app-1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + _install_workflow_service(monkeypatch, workflow=None) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "workflow app-1 not found" + + +def test_instruction_generate_node_missing(app, 
monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + app_model = SimpleNamespace(id="app-1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + workflow = SimpleNamespace(graph_dict={"nodes": []}) + _install_workflow_service(monkeypatch, workflow=workflow) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "node node-1 not found" + + +def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + app_model = SimpleNamespace(id="app-1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + workflow = SimpleNamespace( + graph_dict={ + "nodes": [ + {"id": "node-1", "data": {"type": "code"}}, + ] + } + ) + _install_workflow_service(monkeypatch, workflow=workflow) + monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"}) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response = method() + + assert response == {"code": "x"} + + +def 
test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr( + generator_module.LLMGenerator, + "instruction_modify_legacy", + lambda **_kwargs: {"instruction": "ok"}, + ) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "", + "current": "old", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response = method() + + assert response == {"instruction": "ok"} + + +def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "", + "current": "", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "incompatible parameters" + + +def test_instruction_template_prompt(app) -> None: + api = generator_module.InstructionGenerationTemplateApi() + method = _unwrap(api.post) + + with app.test_request_context( + "/console/api/instruction-generate/template", + method="POST", + json={"type": "prompt"}, + ): + response = method() + + assert "data" in response + + +def test_instruction_template_invalid_type(app) -> None: + api = generator_module.InstructionGenerationTemplateApi() + method = _unwrap(api.post) + + with app.test_request_context( + "/console/api/instruction-generate/template", + method="POST", + json={"type": "unknown"}, + ): + with pytest.raises(ValueError): + method() diff --git 
a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py new file mode 100644 index 0000000000..a76e958829 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import pytest + +from controllers.console.app import message as message_module + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test valid ChatMessagesQuery with all fields.""" + query = message_module.ChatMessagesQuery( + conversation_id="550e8400-e29b-41d4-a716-446655440000", + first_id="550e8400-e29b-41d4-a716-446655440001", + limit=50, + ) + assert query.limit == 50 + + +def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test ChatMessagesQuery with defaults.""" + query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000") + assert query.first_id is None + assert query.limit == 20 + + +def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test ChatMessagesQuery converts empty first_id to None.""" + query = message_module.ChatMessagesQuery( + conversation_id="550e8400-e29b-41d4-a716-446655440000", + first_id="", + ) + assert query.first_id is None + + +def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageFeedbackPayload with like rating.""" + payload = message_module.MessageFeedbackPayload( + message_id="550e8400-e29b-41d4-a716-446655440000", + rating="like", + content="Good answer", + ) + assert payload.rating == "like" + assert payload.content == "Good answer" + + +def 
test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageFeedbackPayload with dislike rating.""" + payload = message_module.MessageFeedbackPayload( + message_id="550e8400-e29b-41d4-a716-446655440000", + rating="dislike", + ) + assert payload.rating == "dislike" + + +def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageFeedbackPayload without rating.""" + payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000") + assert payload.rating is None + + +def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with default format.""" + query = message_module.FeedbackExportQuery() + assert query.format == "csv" + assert query.from_source is None + + +def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with JSON format.""" + query = message_module.FeedbackExportQuery(format="json") + assert query.format == "json" + + +def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as true string.""" + query = message_module.FeedbackExportQuery(has_comment="true") + assert query.has_comment is True + + +def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as false string.""" + query = message_module.FeedbackExportQuery(has_comment="false") + assert query.has_comment is False + + +def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as 1.""" + query = message_module.FeedbackExportQuery(has_comment="1") + assert query.has_comment is True + + +def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test 
FeedbackExportQuery with has_comment as 0.""" + query = message_module.FeedbackExportQuery(has_comment="0") + assert query.has_comment is False + + +def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with rating filter.""" + query = message_module.FeedbackExportQuery(rating="like") + assert query.rating == "like" + + +def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test AnnotationCountResponse creation.""" + response = message_module.AnnotationCountResponse(count=10) + assert response.count == 10 + + +def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test SuggestedQuestionsResponse creation.""" + response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) + assert len(response.data) == 2 + assert response.data[0] == "What is AI?" diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py new file mode 100644 index 0000000000..61d92bb5c7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import model_config as model_config_module +from models.model import AppMode, AppModelConfig + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = model_config_module.ModelConfigResource() + method = _unwrap(api.post) + + app_model = SimpleNamespace( + id="app-1", + mode=AppMode.CHAT.value, + 
is_agent=False, + app_model_config_id=None, + updated_by=None, + updated_at=None, + ) + monkeypatch.setattr( + model_config_module.AppModelConfigService, + "validate_configuration", + lambda **_kwargs: {"pre_prompt": "hi"}, + ) + monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + session = MagicMock() + monkeypatch.setattr(model_config_module.db, "session", session) + + def _from_model_config_dict(self, model_config): + self.pre_prompt = model_config["pre_prompt"] + self.id = "config-1" + return self + + monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict) + send_mock = MagicMock() + monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) + + with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): + response = method(app_model=app_model) + + session.add.assert_called_once() + session.flush.assert_called_once() + session.commit.assert_called_once() + send_mock.assert_called_once() + assert app_model.app_model_config_id == "config-1" + assert response["result"] == "success" + + +def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = model_config_module.ModelConfigResource() + method = _unwrap(api.post) + + app_model = SimpleNamespace( + id="app-1", + mode=AppMode.AGENT_CHAT.value, + is_agent=True, + app_model_config_id="config-0", + updated_by=None, + updated_at=None, + ) + + original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1") + original_config.agent_mode = json.dumps( + { + "enabled": True, + "strategy": "function-calling", + "tools": [ + { + "provider_id": "provider", + "provider_type": "builtin", + "tool_name": "tool", + "tool_parameters": {"secret": "masked"}, + } + ], + "prompt": None, + } + ) + + session = MagicMock() + query = MagicMock() + query.where.return_value = query + 
query.first.return_value = original_config + session.query.return_value = query + monkeypatch.setattr(model_config_module.db, "session", session) + + monkeypatch.setattr( + model_config_module.AppModelConfigService, + "validate_configuration", + lambda **_kwargs: { + "pre_prompt": "hi", + "agent_mode": { + "enabled": True, + "strategy": "function-calling", + "tools": [ + { + "provider_id": "provider", + "provider_type": "builtin", + "tool_name": "tool", + "tool_parameters": {"secret": "masked"}, + } + ], + "prompt": None, + }, + }, + ) + monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object()) + + class _ParamManager: + def __init__(self, **_kwargs): + self.delete_called = False + + def decrypt_tool_parameters(self, _value): + return {"secret": "decrypted"} + + def mask_tool_parameters(self, _value): + return {"secret": "masked"} + + def encrypt_tool_parameters(self, _value): + return {"secret": "encrypted"} + + def delete_tool_parameters_cache(self): + self.delete_called = True + + monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager) + send_mock = MagicMock() + monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) + + with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): + response = method(app_model=app_model) + + stored_config = session.add.call_args[0][0] + stored_agent_mode = json.loads(stored_config.agent_mode) + assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted" + assert response["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py new file mode 100644 index 0000000000..15459994f9 --- /dev/null +++ 
b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from decimal import Decimal +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import BadRequest + +from controllers.console.app import statistic as statistic_module + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _ConnContext: + def __init__(self, rows): + self._rows = rows + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _query, _args): + return self._rows + + +def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None: + engine = SimpleNamespace(begin=lambda: _ConnContext(rows)) + monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine)) + + +def _install_common(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: (None, None), + ) + monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) + + +def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-01", message_count=3)] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]} + + +def 
test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyConversationStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} + + +def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyTokenCostStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + data = response.get_json() + assert len(data["data"]) == 1 + assert data["data"][0]["date"] == "2024-01-03" + assert data["data"][0]["token_count"] == 10 + assert data["data"][0]["total_price"] == 0.25 + + +def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyTerminalsStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]} + + +def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that 
AverageSessionInteractionStatistic is limited to chat/agent modes.""" + # This just verifies the decorator is applied correctly + # Actual endpoint testing would require complex JOIN mocking + api = statistic_module.AverageSessionInteractionStatistic() + method = _unwrap(api.get) + assert callable(method) + + +def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + def mock_parse(*args, **kwargs): + raise ValueError("Invalid time range") + + _install_db(monkeypatch, []) + monkeypatch.setattr( + statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse) + monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + with pytest.raises(BadRequest): + method(app_model=SimpleNamespace(id="app-1")) + + +def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + rows = [ + SimpleNamespace(date="2024-01-01", message_count=10), + SimpleNamespace(date="2024-01-02", message_count=15), + SimpleNamespace(date="2024-01-03", message_count=12), + ] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + data = response.get_json() + assert len(data["data"]) == 3 + + +def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + _install_common(monkeypatch) + _install_db(monkeypatch, []) + + with 
app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": []} + + +def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyConversationStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] + _install_db(monkeypatch, rows) + monkeypatch.setattr( + statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: ("s", "e"), + ) + monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} + + +def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyTokenCostStatistic() + method = _unwrap(api.get) + + rows = [ + SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"), + SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"), + ] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + data = response.get_json() + assert len(data["data"]) == 2 diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py new file mode 100644 index 0000000000..f100080eaa --- /dev/null +++ 
b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import HTTPException, NotFound + +from controllers.console.app import workflow as workflow_module +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None) + workflow = SimpleNamespace(features_dict={}, tenant_id="t1") + + assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == [] + + +def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: + config = object() + file_list = [ + File( + tenant_id="t1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="http://u", + ) + ] + build_mock = Mock(return_value=file_list) + monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config) + monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock) + + workflow = SimpleNamespace(features_dict={}, tenant_id="t1") + result = workflow_module._parse_file(workflow, files=[{"id": "f"}]) + + assert result == file_list + build_mock.assert_called_once() + + +def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + with 
app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"): + with pytest.raises(HTTPException) as exc: + handler(api, app_model=SimpleNamespace(id="app")) + + assert exc.value.code == 415 + + +def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + with app.test_request_context( + "/apps/app/workflows/draft", + method="POST", + data="[]", + content_type="application/json", + ): + response, status = handler(api, app_model=SimpleNamespace(id="app")) + + assert status == 400 + assert response["message"] == "Invalid JSON data" + + +def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + unique_hash="h", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + monkeypatch.setattr( + workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env" + ) + monkeypatch.setattr( + workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv" + ) + + service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow) + monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service) + + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft", + method="POST", + json={"graph": {}, "features": {}, "hash": "h"}, + ): + response = handler(api, app_model=SimpleNamespace(id="app")) + + assert response["result"] == "success" + + +def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: 
(SimpleNamespace(), "t1")) + + def _raise(*_args, **_kwargs): + raise workflow_module.WorkflowHashNotEqualError() + + service = SimpleNamespace(sync_draft_workflow=_raise) + monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service) + + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft", + method="POST", + json={"graph": {}, "features": {}, "hash": "h"}, + ): + with pytest.raises(DraftWorkflowNotSync): + handler(api, app_model=SimpleNamespace(id="app")) + + +def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) + ) + + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.get) + + with pytest.raises(DraftWorkflowNotExist): + handler(api, app_model=SimpleNamespace(id="app")) + + +def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + workflow_module.AppGenerateService, + "generate", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + workflow_module.services.errors.conversation.ConversationNotExistsError() + ), + ) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + api = workflow_module.AdvancedChatDraftWorkflowRunApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/advanced-chat/workflows/draft/run", + method="POST", + json={"inputs": {}}, + ): + with pytest.raises(NotFound): + handler(api, app_model=SimpleNamespace(id="app")) diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py new file mode 100644 index 0000000000..7664e492da --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from types import 
SimpleNamespace + +import pytest + +from controllers.console.app import wraps as wraps_module +from controllers.console.app.error import AppNotFoundError +from models.model import AppMode + + +def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: + app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + + monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + @wraps_module.get_app_model + def handler(app_model): + return app_model.id + + assert handler(app_id="app-1") == "app-1" + + +def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: + app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + + monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) + def handler(app_model): + return app_model.id + + with pytest.raises(AppNotFoundError): + handler(app_id="app-1") + + +def test_get_app_model_requires_app_id() -> None: + @wraps_module.get_app_model + def handler(app_model): + return app_model.id + + with pytest.raises(ValueError): + handler() diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py 
b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py new file mode 100644 index 0000000000..9014edc39e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -0,0 +1,817 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.datasource_auth import ( + DatasourceAuth, + DatasourceAuthDefaultApi, + DatasourceAuthDeleteApi, + DatasourceAuthListApi, + DatasourceAuthOauthCustomClient, + DatasourceAuthUpdateApi, + DatasourceHardCodeAuthListApi, + DatasourceOAuthCallback, + DatasourcePluginOAuthAuthorizationUrl, + DatasourceUpdateProviderNameApi, +) +from core.plugin.impl.oauth import OAuthHandler +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from services.datasource_provider_service import DatasourceProviderService +from services.plugin.oauth_service import OAuthProxyService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDatasourcePluginOAuthAuthorizationUrl: + def test_get_success(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + user = MagicMock(id="user-1") + + with ( + app.test_request_context("/?credential_id=cred-1"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthProxyService, + "create_proxy_context", + return_value="ctx-1", + ), + patch.object( + OAuthHandler, + "get_authorization_url", + return_value={"url": "http://auth"}, + ), + ): + response = method(api, "notion") + + assert response.status_code == 200 + + def test_get_no_oauth_config(self, app): 
+ api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_without_credential_id_sets_cookie(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + user = MagicMock(id="user-1") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthProxyService, + "create_proxy_context", + return_value="ctx-123", + ), + patch.object( + OAuthHandler, + "get_authorization_url", + return_value={"url": "http://auth"}, + ), + ): + response = method(api, "notion") + + assert response.status_code == 200 + assert "context_id" in response.headers.get("Set-Cookie") + + +class TestDatasourceOAuthCallback: + def test_callback_success_new_credential(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {"name": "test"} + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": None, + } + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + 
DatasourceProviderService, + "add_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + + def test_callback_missing_context(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "notion") + + def test_callback_invalid_context(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + with ( + app.test_request_context("/?context_id=bad"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "notion") + + def test_callback_oauth_config_not_found(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + context = {"user_id": "u", "tenant_id": "t"} + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "notion") + + def test_callback_reauthorize_existing_credential(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {} # avatar + name missing + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": "cred-1", + } + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "reauthorize_datasource_oauth_provider", + 
return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + assert "/oauth-callback" in response.location + + def test_callback_context_id_from_cookie(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {} + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": None, + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "add_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + + +class TestDatasourceAuth: + def test_post_success(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {"credentials": {"key": "val"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "add_datasource_api_key_provider", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_post_invalid_credentials(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {"credentials": {"key": "bad"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + 
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "add_datasource_api_key_provider", + side_effect=CredentialsValidateFailedError("invalid"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_success(self, app): + api = DatasourceAuth() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "list_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api, "notion") + + assert status == 200 + assert response["result"] + + def test_post_missing_credentials(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_empty_list(self, app): + api = DatasourceAuth() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "list_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api, "notion") + + assert status == 200 + assert response["result"] == [] + + +class TestDatasourceAuthDeleteApi: + def test_delete_success(self, app): + api = DatasourceAuthDeleteApi() + method = unwrap(api.post) + + payload = {"credential_id": "cred-1"} + + with ( + app.test_request_context("/", 
json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "remove_datasource_credentials", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_delete_missing_credential_id(self, app): + api = DatasourceAuthDeleteApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + +class TestDatasourceAuthUpdateApi: + def test_update_success(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": {"k": "v"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 201 + + def test_update_with_credentials_none(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": None} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + 
"update_datasource_credentials", + return_value=None, + ) as update_mock, + ): + response, status = method(api, "notion") + + update_mock.assert_called_once() + assert status == 201 + + def test_update_name_only(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "name": "New Name"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ), + ): + _, status = method(api, "notion") + + assert status == 201 + + def test_update_with_empty_credentials_dict(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": {}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ) as update_mock, + ): + _, status = method(api, "notion") + + update_mock.assert_called_once() + assert status == 201 + + +class TestDatasourceAuthListApi: + def test_list_success(self, app): + api = DatasourceAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_all_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api) + + assert status == 200 + + def test_auth_list_empty(self, app): + api = 
DatasourceAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_all_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api) + + assert status == 200 + assert response["result"] == [] + + def test_hardcode_list_empty(self, app): + api = DatasourceHardCodeAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_hard_code_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api) + + assert status == 200 + assert response["result"] == [] + + +class TestDatasourceHardCodeAuthListApi: + def test_list_success(self, app): + api = DatasourceHardCodeAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_hard_code_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api) + + assert status == 200 + + +class TestDatasourceAuthOauthCustomClient: + def test_post_success(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = {"client_params": {}, "enable_oauth_custom_client": True} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + 
DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_delete_success(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "remove_oauth_custom_client_params", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_post_empty_payload(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ), + ): + _, status = method(api, "notion") + + assert status == 200 + + def test_post_disabled_flag(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = { + "client_params": {"a": 1}, + "enable_oauth_custom_client": False, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ) as setup_mock, + ): + _, status = method(api, "notion") + + setup_mock.assert_called_once() + assert status == 200 + + +class TestDatasourceAuthDefaultApi: + def test_set_default_success(self, app): + api = 
DatasourceAuthDefaultApi() + method = unwrap(api.post) + + payload = {"id": "cred-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "set_default_datasource_provider", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_default_missing_id(self, app): + api = DatasourceAuthDefaultApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + +class TestDatasourceUpdateProviderNameApi: + def test_update_name_success(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "name": "New Name"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_provider_name", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_update_name_too_long(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = { + "credential_id": "id", + "name": "x" * 101, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + 
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_update_name_missing_credential_id(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = {"name": "Valid"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py new file mode 100644 index 0000000000..7a8ccde55a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.datasource_content_preview import ( + DataSourceContentPreviewApi, +) +from models import Account +from models.dataset import Pipeline + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDataSourceContentPreviewApi: + def _valid_payload(self): + return { + "inputs": {"query": "hello"}, + "datasource_type": "notion", + "credential_id": "cred-1", + } + + def test_post_success(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = self._valid_payload() + + pipeline = MagicMock(spec=Pipeline) + node_id = "node-1" + account = MagicMock(spec=Account) + + preview_result = {"content": "preview data"} + + service_instance 
= MagicMock() + service_instance.run_datasource_node_preview.return_value = preview_result + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", + return_value=service_instance, + ), + ): + response, status = method(api, pipeline, node_id) + + service_instance.run_datasource_node_preview.assert_called_once_with( + pipeline=pipeline, + node_id=node_id, + user_inputs=payload["inputs"], + account=account, + datasource_type=payload["datasource_type"], + is_published=True, + credential_id=payload["credential_id"], + ) + assert status == 200 + assert response == preview_result + + def test_post_forbidden_non_account_user(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = self._valid_payload() + + pipeline = MagicMock(spec=Pipeline) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + MagicMock(), # NOT Account + ), + ): + with pytest.raises(Forbidden): + method(api, pipeline, "node-1") + + def test_post_invalid_payload(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = { + "inputs": {"query": "hello"}, + # datasource_type missing + } + + pipeline = MagicMock(spec=Pipeline) + account = MagicMock(spec=Account) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + ): + with pytest.raises(ValueError): + method(api, pipeline, "node-1") + + def test_post_without_credential_id(self, app): + api = 
DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = { + "inputs": {"query": "hello"}, + "datasource_type": "notion", + "credential_id": None, + } + + pipeline = MagicMock(spec=Pipeline) + account = MagicMock(spec=Account) + + service_instance = MagicMock() + service_instance.run_datasource_node_preview.return_value = {"ok": True} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", + return_value=service_instance, + ), + ): + response, status = method(api, pipeline, "node-1") + + service_instance.run_datasource_node_preview.assert_called_once() + assert status == 200 + assert response == {"ok": True} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py new file mode 100644 index 0000000000..3b8679f4ec --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -0,0 +1,187 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.rag_pipeline import ( + CustomizedPipelineTemplateApi, + PipelineTemplateDetailApi, + PipelineTemplateListApi, + PublishCustomizedPipelineTemplateApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestPipelineTemplateListApi: + def test_get_success(self, app): + api = PipelineTemplateListApi() + method = unwrap(api.get) + + templates = [{"id": "t1"}] + + with ( + app.test_request_context("/?type=built-in&language=en-US"), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates", + return_value=templates, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == templates + + +class TestPipelineTemplateDetailApi: + def test_get_success(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + template = {"id": "tpl-1"} + + service = MagicMock() + service.get_pipeline_template_detail.return_value = template + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "tpl-1") + + assert status == 200 + assert response == template + + +class TestCustomizedPipelineTemplateApi: + def test_patch_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.patch) + + payload = { + "name": "Template", + "description": "Desc", + "icon_info": {"icon": "📘"}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template" + ) as update_mock, + ): + response = method(api, "tpl-1") + + update_mock.assert_called_once() + assert response == 200 + + def test_delete_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template" + ) as delete_mock, + ): + response = method(api, "tpl-1") + + delete_mock.assert_called_once_with("tpl-1") + assert response == 200 + + def test_post_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.post) + + template = MagicMock() + template.yaml_content = "yaml-data" + + fake_db = MagicMock() + fake_db.engine = 
MagicMock() + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = template + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", + return_value=session_ctx, + ), + ): + response, status = method(api, "tpl-1") + + assert status == 200 + assert response == {"data": "yaml-data"} + + def test_post_template_not_found(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.post) + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", + return_value=session_ctx, + ), + ): + with pytest.raises(ValueError): + method(api, "tpl-1") + + +class TestPublishCustomizedPipelineTemplateApi: + def test_post_success(self, app): + api = PublishCustomizedPipelineTemplateApi() + method = unwrap(api.post) + + payload = { + "name": "Template", + "description": "Desc", + "icon_info": {"icon": "📘"}, + } + + service = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response = method(api, "pipeline-1") + + service.publish_customized_pipeline_template.assert_called_once() + assert response == {"result": "success"} diff --git 
a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py new file mode 100644 index 0000000000..fd38fcbb5e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -0,0 +1,187 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +import services +from controllers.console import console_ns +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import ( + CreateEmptyRagPipelineDatasetApi, + CreateRagPipelineDatasetApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestCreateRagPipelineDatasetApi: + def _valid_payload(self): + return {"yaml_content": "name: test"} + + def test_post_success(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=True) + import_info = {"dataset_id": "ds-1"} + + mock_service = MagicMock() + mock_service.create_rag_pipeline_dataset.return_value = import_info + + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__.return_value = MagicMock() + mock_session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", + return_value=mock_session_ctx, + ), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", + return_value=mock_service, + ), + ): + response, status = method(api) + + assert status == 201 + assert response == import_info + + def test_post_forbidden_non_editor(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_dataset_name_duplicate(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=True) + + mock_service = MagicMock() + mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError() + + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__.return_value = MagicMock() + mock_session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", + return_value=mock_session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", + return_value=mock_service, + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_post_invalid_payload(self, app): + api = 
CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = {} + user = MagicMock(is_dataset_editor=True) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestCreateEmptyRagPipelineDatasetApi: + def test_post_success(self, app): + api = CreateEmptyRagPipelineDatasetApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal", + return_value={"id": "ds-1"}, + ), + ): + response, status = method(api) + + assert status == 201 + assert response == {"id": "ds-1"} + + def test_post_forbidden_non_editor(self, app): + api = CreateEmptyRagPipelineDatasetApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py new file mode 100644 index 0000000000..b4c0903f63 --- /dev/null +++ 
b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -0,0 +1,324 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Response + +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import ( + RagPipelineEnvironmentVariableCollectionApi, + RagPipelineNodeVariableCollectionApi, + RagPipelineSystemVariableCollectionApi, + RagPipelineVariableApi, + RagPipelineVariableCollectionApi, + RagPipelineVariableResetApi, +) +from controllers.web.error import InvalidArgumentError, NotFoundError +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType +from models.account import Account + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def fake_db(): + db = MagicMock() + db.engine = MagicMock() + db.session.return_value = MagicMock() + return db + + +@pytest.fixture +def editor_user(): + user = MagicMock(spec=Account) + user.has_edit_permission = True + return user + + +@pytest.fixture +def restx_config(app): + return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"}) + + +class TestRagPipelineVariableCollectionApi: + def test_get_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + rag_srv = MagicMock() + rag_srv.is_workflow_exist.return_value = True + + # IMPORTANT: RESTX expects .variables + var_list = MagicMock() + var_list.variables = [] + + draft_srv = MagicMock() + draft_srv.list_variables_without_values.return_value = var_list + + with ( + app.test_request_context("/?page=1&limit=10"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + 
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=draft_srv, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [] + + def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + rag_srv = MagicMock() + rag_srv.is_workflow_exist.return_value = False + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + ): + with pytest.raises(DraftWorkflowNotExist): + method(api, pipeline) + + def test_delete_variables_success(self, app, fake_db, editor_user): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.delete) + + pipeline = MagicMock(id="p1") + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"), + ): + result = method(api, pipeline) + + assert isinstance(result, Response) + assert result.status_code == 204 + + +class TestRagPipelineNodeVariableCollectionApi: + def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineNodeVariableCollectionApi() + method = 
unwrap(api.get) + + pipeline = MagicMock(id="p1") + + var_list = MagicMock() + var_list.variables = [] + + srv = MagicMock() + srv.list_node_variables.return_value = var_list + + with ( + app.test_request_context("/"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline, "node1") + + assert result["items"] == [] + + def test_get_node_variables_invalid_node(self, app, editor_user): + api = RagPipelineNodeVariableCollectionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + ): + with pytest.raises(InvalidArgumentError): + method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID) + + +class TestRagPipelineVariableApi: + def test_get_variable_not_found(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.get) + + srv = MagicMock() + srv.get_variable.return_value = None + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + with pytest.raises(NotFoundError): + method(api, MagicMock(), "v1") + + def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.patch) + + pipeline = MagicMock(id="p1", tenant_id="t1") + variable = MagicMock(app_id="p1", value_type=SegmentType.FILE) 
+ + srv = MagicMock() + srv.get_variable.return_value = variable + + payload = {"value": "invalid"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + with pytest.raises(InvalidArgumentError): + method(api, pipeline, "v1") + + def test_delete_variable_success(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.delete) + + pipeline = MagicMock(id="p1") + variable = MagicMock(app_id="p1") + + srv = MagicMock() + srv.get_variable.return_value = variable + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline, "v1") + + assert result.status_code == 204 + + +class TestRagPipelineVariableResetApi: + def test_reset_variable_success(self, app, fake_db, editor_user): + api = RagPipelineVariableResetApi() + method = unwrap(api.put) + + pipeline = MagicMock(id="p1") + workflow = MagicMock() + variable = MagicMock(app_id="p1") + + srv = MagicMock() + srv.get_variable.return_value = variable + srv.reset_variable.return_value = variable + + rag_srv = MagicMock() + rag_srv.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + 
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal", + return_value={"id": "v1"}, + ), + ): + result = method(api, pipeline, "v1") + + assert result == {"id": "v1"} + + +class TestSystemAndEnvironmentVariablesApi: + def test_system_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineSystemVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + var_list = MagicMock() + var_list.variables = [] + + srv = MagicMock() + srv.list_system_variables.return_value = var_list + + with ( + app.test_request_context("/"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [] + + def test_environment_variables_success(self, app, editor_user): + api = RagPipelineEnvironmentVariableCollectionApi() + method = unwrap(api.get) + + env_var = MagicMock( + id="e1", + name="ENV", + description="d", + selector="s", + value_type=MagicMock(value="string"), + value="x", + ) + + workflow = MagicMock(environment_variables=[env_var]) + pipeline = MagicMock(id="p1") + + rag_srv = MagicMock() + rag_srv.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + 
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + ): + result = method(api, pipeline) + + assert len(result["items"]) == 1 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py new file mode 100644 index 0000000000..a72ad45110 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -0,0 +1,329 @@ +from unittest.mock import MagicMock, patch + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( + RagPipelineExportApi, + RagPipelineImportApi, + RagPipelineImportCheckDependenciesApi, + RagPipelineImportConfirmApi, +) +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestRagPipelineImportApi: + def _payload(self, mode="create"): + return { + "mode": mode, + "yaml_content": "content", + "name": "Test", + } + + def test_post_success_200(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = "completed" + result.model_dump.return_value = {"status": "success"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == {"status": "success"} + + def test_post_failed_400(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.FAILED + result.model_dump.return_value = {"status": "failed"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 400 + assert response == {"status": "failed"} + + def test_post_pending_202(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = 
ImportStatus.PENDING + result.model_dump.return_value = {"status": "pending"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 202 + assert response == {"status": "pending"} + + +class TestRagPipelineImportConfirmApi: + def test_confirm_success(self, app): + api = RagPipelineImportConfirmApi() + method = unwrap(api.post) + + user = MagicMock() + result = MagicMock() + result.status = "completed" + result.model_dump.return_value = {"ok": True} + + service = MagicMock() + service.confirm_import.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, "import-1") + + assert status == 200 + assert response == {"ok": True} + + def test_confirm_failed(self, app): + api = RagPipelineImportConfirmApi() + method = unwrap(api.post) + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.FAILED + result.model_dump.return_value = {"ok": False} + + service = MagicMock() + service.confirm_import.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, "import-1") + + assert status == 400 + assert response == {"ok": False} + + +class TestRagPipelineImportCheckDependenciesApi: + def test_get_success(self, app): + api = RagPipelineImportCheckDependenciesApi() + method = unwrap(api.get) + + pipeline = MagicMock(spec=Pipeline) + result = MagicMock() + result.model_dump.return_value = {"deps": []} + + service = MagicMock() + service.check_dependencies.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, pipeline) + + assert status == 200 + assert response == {"deps": []} + + +class TestRagPipelineExportApi: + def test_get_with_include_secret(self, app): + api = RagPipelineExportApi() + method = unwrap(api.get) + + pipeline = MagicMock(spec=Pipeline) + service = MagicMock() + service.export_rag_pipeline_dsl.return_value = {"yaml": "data"} + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/?include_secret=true"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, pipeline) + + assert status == 200 + assert response == {"data": {"yaml": "data"}} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py new file mode 100644 index 0000000000..7775cbdd81 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -0,0 +1,688 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import 
console_ns +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( + DefaultRagPipelineBlockConfigApi, + DraftRagPipelineApi, + DraftRagPipelineRunApi, + PublishedAllRagPipelineApi, + PublishedRagPipelineApi, + PublishedRagPipelineRunApi, + RagPipelineByIdApi, + RagPipelineDatasourceVariableApi, + RagPipelineDraftNodeRunApi, + RagPipelineDraftRunIterationNodeApi, + RagPipelineDraftRunLoopNodeApi, + RagPipelineRecommendedPluginApi, + RagPipelineTaskStopApi, + RagPipelineTransformApi, + RagPipelineWorkflowLastRunApi, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from services.errors.app import WorkflowHashNotEqualError +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDraftWorkflowApi: + def test_get_draft_success(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + workflow = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + assert result == workflow + + def test_get_draft_not_exist(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + service = MagicMock() + service.get_draft_workflow.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(DraftWorkflowNotExist): + method(api, pipeline) + + def test_sync_hash_not_match(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.post) + + pipeline = 
MagicMock() + user = MagicMock() + + service = MagicMock() + service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError() + + with ( + app.test_request_context("/", json={"graph": {}, "features": {}}), + patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(DraftWorkflowNotSync): + method(api, pipeline) + + def test_sync_invalid_text_plain(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + response, status = method(api, pipeline) + assert status == 400 + + +class TestDraftRunNodes: + def test_iteration_node_success(self, app): + api = RagPipelineDraftRunIterationNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + result = method(api, pipeline, "node") + assert result == {"ok": True} + + def test_iteration_node_conversation_not_exists(self, 
app): + api = RagPipelineDraftRunIterationNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", + side_effect=services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "node") + + def test_loop_node_success(self, app): + api = RagPipelineDraftRunLoopNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + assert method(api, pipeline, "node") == {"ok": True} + + +class TestPipelineRunApis: + def test_draft_run_success(self, app): + api = DraftRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + 
), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + assert method(api, pipeline) == {"ok": True} + + def test_draft_run_rate_limit(self, app): + api = DraftRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context( + "/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"} + ), + patch.object( + type(console_ns), + "payload", + {"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, pipeline) + + +class TestDraftNodeRun: + def test_execution_not_found(self, app): + api = RagPipelineDraftNodeRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + service = MagicMock() + service.run_draft_workflow_node.return_value = None + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(ValueError): + method(api, pipeline, "node") + + +class TestPublishedPipelineApis: + def test_publish_success(self, app): + api = 
PublishedRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + workflow = MagicMock( + id="w1", + created_at=datetime.utcnow(), + ) + + session = MagicMock() + session.merge.return_value = pipeline + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + service = MagicMock() + service.publish_workflow.return_value = workflow + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + + assert result["result"] == "success" + assert "created_at" in result + + +class TestMiscApis: + def test_task_stop(self, app): + api = RagPipelineTaskStopApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag" + ) as stop_mock, + ): + result = method(api, pipeline, "task-1") + stop_mock.assert_called_once() + assert result["result"] == "success" + + def test_transform_forbidden(self, app): + api = RagPipelineTransformApi() + method = unwrap(api.post) + + user = MagicMock(has_edit_permission=False, is_dataset_operator=False) + + with ( + app.test_request_context("/"), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds1") + + def test_recommended_plugins(self, app): + api = RagPipelineRecommendedPluginApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_recommended_plugins.return_value = [{"id": "p1"}] + + with ( + app.test_request_context("/?type=all"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api) + assert result == [{"id": "p1"}] + + +class TestPublishedRagPipelineRunApi: + def test_published_run_success(self, app): + api = PublishedRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + "response_mode": "blocking", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + result = method(api, pipeline) + assert result == {"ok": True} + + def test_published_run_rate_limit(self, app): + api = PublishedRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, pipeline) + + +class TestDefaultBlockConfigApi: + def test_get_block_config_success(self, app): + api = DefaultRagPipelineBlockConfigApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + service = MagicMock() + service.get_default_block_config.return_value = {"k": "v"} + + with ( + app.test_request_context("/?q={}"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "llm") + assert result == {"k": "v"} + + def test_get_block_config_invalid_json(self, app): + api = DefaultRagPipelineBlockConfigApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + with app.test_request_context("/?q=bad-json"): + with pytest.raises(ValueError): + method(api, pipeline, "llm") + + +class TestPublishedAllRagPipelineApi: + def test_get_published_workflows_success(self, app): + api = PublishedAllRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + service = MagicMock() + service.get_all_published_workflow.return_value = ([{"id": "w1"}], False) + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [{"id": "w1"}] + assert result["has_more"] is False + + def test_get_published_workflows_forbidden(self, app): + api = PublishedAllRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + with ( + app.test_request_context("/?user_id=u2"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + with pytest.raises(Forbidden): + method(api, pipeline) + + +class TestRagPipelineByIdApi: + def test_patch_success(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.patch) + + pipeline = MagicMock(tenant_id="t1") + user = MagicMock(id="u1") + + workflow = MagicMock() + + service = MagicMock() + service.update_workflow.return_value = workflow + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + payload = {"marked_name": "test"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "w1") + + assert result == workflow + + 
def test_patch_no_fields(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.patch) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={}), + patch.object(type(console_ns), "payload", {}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + result, status = method(api, pipeline, "w1") + assert status == 400 + + +class TestRagPipelineWorkflowLastRunApi: + def test_last_run_success(self, app): + api = RagPipelineWorkflowLastRunApi() + method = unwrap(api.get) + + pipeline = MagicMock() + workflow = MagicMock() + node_exec = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = workflow + service.get_node_last_run.return_value = node_exec + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "node1") + assert result == node_exec + + def test_last_run_not_found(self, app): + api = RagPipelineWorkflowLastRunApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "node1") + + +class TestRagPipelineDatasourceVariableApi: + def test_set_datasource_variables_success(self, app): + api = RagPipelineDatasourceVariableApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "datasource_type": "db", + "datasource_info": {}, + "start_node_id": "n1", + "start_node_title": "Node", + } + + service = MagicMock() + service.set_datasource_variables.return_value = MagicMock() + + with ( + 
app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + assert result is not None diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py new file mode 100644 index 0000000000..3060062adf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -0,0 +1,444 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.console.datasets import data_source +from controllers.console.datasets.data_source import ( + DataSourceApi, + DataSourceNotionApi, + DataSourceNotionDatasetSyncApi, + DataSourceNotionDocumentSyncApi, + DataSourceNotionListApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_ctx(): + return (MagicMock(id="u1"), "tenant-1") + + +@pytest.fixture +def patch_tenant(tenant_ctx): + with patch( + "controllers.console.datasets.data_source.current_account_with_tenant", + return_value=tenant_ctx, + ): + yield + + +@pytest.fixture +def mock_engine(): + with patch.object( + type(data_source.db), + "engine", + new_callable=PropertyMock, + return_value=MagicMock(), + ): + yield + + +class TestDataSourceApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceApi() + method = unwrap(api.get) + + binding = MagicMock( + id="b1", + provider="notion", + created_at="now", + disabled=False, + source_info={}, + ) + + with ( + app.test_request_context("/"), + patch( + 
"controllers.console.datasets.data_source.db.session.scalars", + return_value=MagicMock(all=lambda: [binding]), + ), + ): + response, status = method(api) + + assert status == 200 + assert response["data"][0]["is_bound"] is True + + def test_get_no_bindings(self, app, patch_tenant): + api = DataSourceApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.db.session.scalars", + return_value=MagicMock(all=lambda: []), + ), + ): + response, status = method(api) + + assert status == 200 + assert response["data"] == [] + + def test_patch_enable_binding(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + response, status = method(api, "b1", "enable") + + assert status == 200 + assert binding.disabled is False + + def test_patch_disable_binding(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=False) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + response, status = method(api, "b1", 
"disable") + + assert status == 200 + assert binding.disabled is True + + def test_patch_binding_not_found(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = None + + with pytest.raises(NotFound): + method(api, "b1", "enable") + + def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=False) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + with pytest.raises(ValueError): + method(api, "b1", "enable") + + def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + with pytest.raises(ValueError): + method(api, "b1", "disable") + + +class TestDataSourceNotionListApi: + def test_get_credential_not_found(self, app, patch_tenant): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + 
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api) + + def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + page = MagicMock( + page_id="p1", + page_name="Page 1", + type="page", + parent_id="parent", + page_icon=None, + ) + + online_document_message = MagicMock( + result=[ + MagicMock( + workspace_id="w1", + workspace_name="My Workspace", + workspace_icon="icon", + pages=[page], + ) + ] + ) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=MagicMock( + get_online_document_pages=lambda **kw: iter([online_document_message]), + datasource_provider_type=lambda: None, + ), + ), + ): + response, status = method(api) + + assert status == 200 + + def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + page = MagicMock( + page_id="p1", + page_name="Page 1", + type="page", + parent_id="parent", + page_icon=None, + ) + + online_document_message = MagicMock( + result=[ + MagicMock( + workspace_id="w1", + workspace_name="My Workspace", + workspace_icon="icon", + pages=[page], + ) + ] + ) + + dataset = MagicMock(data_source_type="notion_import") + document = MagicMock(data_source_info='{"notion_page_id": "p1"}') + + with ( + app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=dataset, + 
), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=MagicMock( + get_online_document_pages=lambda **kw: iter([online_document_message]), + datasource_provider_type=lambda: None, + ), + ), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [document] + + response, status = method(api) + + assert status == 200 + + def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + dataset = MagicMock(data_source_type="other_type") + + with ( + app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.data_source.Session"), + ): + with pytest.raises(ValueError): + method(api) + + +class TestDataSourceNotionApi: + def test_get_preview_success(self, app, patch_tenant): + api = DataSourceNotionApi() + method = unwrap(api.get) + + extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")]) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"integration_secret": "t"}, + ), + patch( + "controllers.console.datasets.data_source.NotionExtractor", + return_value=extractor, + ), + ): + response, status = method(api, "p1", "page") + + assert status == 200 + + def test_post_indexing_estimate_success(self, app, patch_tenant): + api = DataSourceNotionApi() + method = unwrap(api.post) + + payload = { + "notion_info_list": [ + { + 
"workspace_id": "w1", + "credential_id": "c1", + "pages": [{"page_id": "p1", "type": "page"}], + } + ], + "process_rule": {"rules": {}}, + "doc_form": "text_model", + "doc_language": "English", + } + + with ( + app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}), + patch( + "controllers.console.datasets.data_source.DocumentService.estimate_args_validate", + ), + patch( + "controllers.console.datasets.data_source.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"total_pages": 1}), + ), + ): + response, status = method(api) + + assert status == 200 + + +class TestDataSourceNotionDatasetSyncApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceNotionDatasetSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id", + return_value=[MagicMock(id="d1")], + ), + patch( + "controllers.console.datasets.data_source.document_indexing_sync_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 200 + + def test_get_dataset_not_found(self, app, patch_tenant): + api = DataSourceNotionDatasetSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + +class TestDataSourceNotionDocumentSyncApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceNotionDocumentSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + 
"controllers.console.datasets.data_source.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.document_indexing_sync_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_document_not_found(self, app, patch_tenant): + api = DataSourceNotionDocumentSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py new file mode 100644 index 0000000000..f9fc2ac397 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -0,0 +1,1926 @@ +import datetime +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.datasets import ( + DatasetApi, + DatasetApiBaseUrlApi, + DatasetApiDeleteApi, + DatasetApiKeyApi, + DatasetAutoDisableLogApi, + DatasetEnableApiApi, + DatasetErrorDocs, + DatasetIndexingEstimateApi, + DatasetIndexingStatusApi, + DatasetListApi, + DatasetPermissionUserListApi, + DatasetQueryApi, + DatasetRelatedAppListApi, + DatasetRetrievalSettingApi, + DatasetRetrievalSettingMockApi, + DatasetUseCheckApi, +) +from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError +from core.errors.error import LLMBadRequestError, 
ProviderTokenNotInitError +from core.provider_manager import ProviderManager +from models.enums import CreatorUserRole +from models.model import ApiToken, UploadFile +from services.dataset_service import DatasetPermissionService, DatasetService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDatasetList: + def _mock_dataset_dict(self, **overrides): + base = { + "id": "ds-1", + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "permission": "only_me", + } + base.update(overrides) + return base + + def _mock_user(self): + user = MagicMock() + user.is_dataset_editor = True + return user + + def test_get_success_basic(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 1 + assert resp["data"][0]["embedding_available"] is True + + def test_get_with_ids_filter(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets?ids=1&ids=2"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets_by_ids", + 
return_value=(datasets, 2), + ) as by_ids_mock, + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + by_ids_mock.assert_called_once() + assert status == 200 + assert resp["total"] == 2 + + def test_get_with_tag_ids(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets?tag_ids=tag1"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert status == 200 + + def test_embedding_available_false(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [ + self._mock_dataset_dict( + indexing_technique="high_quality", + embedding_model="text-embed", + embedding_model_provider="openai", + ) + ] + + config = MagicMock() + config.get_models.return_value = [] # model not available + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=config, + ), + ): + 
resp, status = method(api) + + assert resp["data"][0]["embedding_available"] is False + + def test_partial_members_permission(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict(permission="partial_members")] + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.db.session.execute", + return_value=MagicMock(all=lambda: [("ds-1", "u1")]), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert resp["data"][0]["partial_member_list"] == ["u1"] + + +class TestDatasetListApiPost: + def test_post_success(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "My Dataset", + "description": "desc", + "indexing_technique": "economy", + "provider": "vendor", + } + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + # ---- minimal required fields for marshal ---- + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), 
"payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasetService, + "create_empty_dataset", + return_value=dataset, + ), + ): + _, status = method(api) + + assert status == 201 + + def test_post_forbidden(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = {"name": "test"} + + user = MagicMock() + user.is_dataset_editor = False + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_duplicate_name(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = {"name": "duplicate"} + + user = MagicMock() + user.is_dataset_editor = True + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasetService, + "create_empty_dataset", + side_effect=services.errors.dataset.DatasetNameDuplicateError(), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_post_invalid_payload_missing_name(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + with app.test_request_context("/datasets", json={}), patch.object(type(console_ns), "payload", {}): + with pytest.raises(ValueError): + method(api) + + def test_post_invalid_indexing_technique(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "bad", + "indexing_technique": "invalid-tech", + } + + with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError, match="Invalid 
indexing technique"): + method(api) + + def test_post_invalid_provider(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "bad", + "provider": "unknown", + } + + with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError, match="Invalid provider"): + method(api) + + +class TestDatasetApiGet: + def test_get_success_basic(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "123e4567-e89b-12d3-a456-426614174000" + + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + dataset.permission = "only_me" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + # embedding models exist → embedding_available stays True + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, status = method(api, dataset_id) + + assert status == 200 + assert data["embedding_available"] is True + + def 
test_get_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "missing-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_get_permission_denied(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + dataset = MagicMock() + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden, match="no access"): + method(api, dataset_id) + + def test_get_high_quality_embedding_unavailable(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model = "text-embedding" + dataset.embedding_model_provider = "openai" + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + dataset.permission 
= "only_me" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + # embedding model NOT configured + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, _ = method(api, dataset_id) + + assert data["embedding_available"] is False + + def test_get_partial_members_permission(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + dataset.permission = "partial_members" + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + partial_members = [{"id": "u1"}, {"id": "u2"}] + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + 
return_value=partial_members, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, _ = method(api, dataset_id) + + assert data["partial_member_list"] == partial_members + + +class TestDatasetApiPatch: + def test_patch_success_basic(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "name": "updated-name", + "description": "updated description", + } + + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.permission = "only_me" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=[], + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert 
result["partial_member_list"] == [] + + def test_patch_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/datasets/missing"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, "missing") + + def test_patch_permission_denied(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + dataset = MagicMock() + + payload = {"name": "x"} + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetPermissionService, + "check_permission", + side_effect=Forbidden("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_patch_partial_members_update(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "permission": "partial_members", + "partial_member_list": [{"id": "u1"}, {"id": "u2"}], + } + + dataset = MagicMock() + dataset.id = dataset_id + dataset.permission = "partial_members" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + 
app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "update_partial_member_list", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=payload["partial_member_list"], + ), + ): + result, _ = method(api, dataset_id) + + assert result["partial_member_list"] == payload["partial_member_list"] + + def test_patch_clear_partial_members(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "permission": "only_me", + } + + dataset = MagicMock() + dataset.id = dataset_id + dataset.permission = "only_me" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + 
patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "clear_partial_member_list", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=[], + ), + ): + result, _ = method(api, dataset_id) + + assert result["partial_member_list"] == [] + + +class TestDatasetApiDelete: + def test_delete_success(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + return_value=True, + ), + patch.object( + DatasetPermissionService, + "clear_partial_member_list", + return_value=None, + ), + ): + result, status = method(api, dataset_id) + + assert status == 204 + assert result == {"result": "success"} + + def test_delete_forbidden_no_permission(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = False + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_delete_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "missing-dataset" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + 
"controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + return_value=False, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_delete_dataset_in_use(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + side_effect=services.errors.dataset.DatasetInUseError(), + ), + ): + with pytest.raises(DatasetInUseError): + method(api, dataset_id) + + +class TestDatasetUseCheckApi: + def test_get_use_check_true(self, app): + api = DatasetUseCheckApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}/use-check"), + patch.object( + DatasetService, + "dataset_use_check", + return_value=True, + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result == {"is_using": True} + + def test_get_use_check_false(self, app): + api = DatasetUseCheckApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}/use-check"), + patch.object( + DatasetService, + "dataset_use_check", + return_value=False, + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result == {"is_using": False} + + +class TestDatasetQueryApi: + def test_get_queries_success(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + current_user = MagicMock() + + dataset = MagicMock() + dataset.id = dataset_id + + queries = [MagicMock(), MagicMock()] + + with ( 
+ app.test_request_context("/datasets/queries?page=1&limit=20"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetService, + "get_dataset_queries", + return_value=(queries, 2), + ), + ): + response, status = method(api, dataset_id) + + assert status == 200 + assert response["total"] == 2 + assert response["page"] == 1 + assert response["limit"] == 20 + assert response["has_more"] is False + assert len(response["data"]) == 2 + + def test_get_queries_dataset_not_found(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + with ( + app.test_request_context("/datasets/queries"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_get_queries_permission_denied(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + dataset = MagicMock() + + with ( + app.test_request_context("/datasets/queries"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_get_queries_pagination_has_more(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + 
+ dataset_id = "dataset-id" + current_user = MagicMock() + + dataset = MagicMock() + dataset.id = dataset_id + + queries = [MagicMock() for _ in range(20)] + + with ( + app.test_request_context("/datasets/queries?page=1&limit=20"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetService, + "get_dataset_queries", + return_value=(queries, 40), + ), + ): + response, status = method(api, dataset_id) + + assert status == 200 + assert response["has_more"] is True + assert len(response["data"]) == 20 + + +class TestDatasetIndexingEstimateApi: + def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key="key", + name="name.txt", + size=1, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-1", + created_at=datetime.datetime.now(tz=datetime.UTC), + used=False, + ) + upload_file.id = file_id + return upload_file + + def _base_payload(self): + return { + "info_list": { + "data_source_type": "upload_file", + "file_info_list": { + "file_ids": ["file-1"], + }, + }, + "process_rule": {"chunk_size": 100}, + "indexing_technique": "high_quality", + "doc_form": "text_model", + "doc_language": "English", + "dataset_id": None, + } + + def test_post_success_upload_file(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + + payload = self._base_payload() + + mock_file = self._upload_file() + + mock_response = MagicMock() + mock_response.model_dump.return_value = {"tokens": 100} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + 
return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + return_value=mock_response, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == {"tokens": 100} + + def test_post_file_not_found(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: None), + ), + ): + with pytest.raises(NotFound): + method(api) + + def test_post_llm_bad_request_error(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + 
return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api) + + def test_post_provider_token_not_init(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api) + + def test_post_generic_exception(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=Exception("boom"), + 
), + ): + with pytest.raises(IndexingEstimateError): + method(api) + + +class TestDatasetRelatedAppListApi: + def test_get_success(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + dataset.id = "dataset-1" + + app1 = MagicMock() + app2 = MagicMock() + + join1 = MagicMock(app=app1) + join2 = MagicMock(app=app2) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_related_apps", + return_value=[join1, join2], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 2 + assert response["data"] == [app1, app2] + + def test_get_dataset_not_found(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") + + def test_get_permission_denied(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + 
side_effect=services.errors.account.NoPermissionError("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "dataset-1") + + def test_get_filters_none_apps(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + dataset.id = "dataset-1" + + app1 = MagicMock() + + join1 = MagicMock(app=app1) + join2 = MagicMock(app=None) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_related_apps", + return_value=[join1, join2], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 1 + assert response["data"] == [app1] + + +class TestDatasetIndexingStatusApi: + def test_get_success_with_documents(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + document = MagicMock() + document.id = "doc-1" + document.indexing_status = "completed" + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.completed_at = None + document.paused_at = None + document.error = None + document.stopped_at = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [document]), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: 
MagicMock(count=lambda: 3)), + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert "data" in response + assert len(response["data"]) == 1 + + item = response["data"][0] + assert item["completed_segments"] == 3 + assert item["total_segments"] == 3 + + def test_get_success_no_documents(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: []), + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response == {"data": []} + + def test_segment_counts_different_values(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + document = MagicMock() + document.id = "doc-1" + document.indexing_status = "indexing" + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.completed_at = None + document.paused_at = None + document.error = None + document.stopped_at = None + + # First count = completed segments, second = total segments + query_mock = MagicMock() + query_mock.where.side_effect = [ + MagicMock(count=lambda: 2), + MagicMock(count=lambda: 5), + ] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [document]), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=query_mock, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + item = response["data"][0] + assert 
item["completed_segments"] == 2 + assert item["total_segments"] == 5 + + +class TestDatasetApiKeyApi: + def test_get_api_keys_success(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.get) + + mock_key_1 = MagicMock(spec=ApiToken) + mock_key_2 = MagicMock(spec=ApiToken) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_key_1, mock_key_2]), + ), + ): + response = method(api) + + assert "items" in response + assert response["items"] == [mock_key_1, mock_key_2] + + def test_post_create_api_key_success(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + ), + patch( + "controllers.console.datasets.datasets.ApiToken.generate_api_key", + return_value="dataset-abc123", + ), + patch( + "controllers.console.datasets.datasets.db.session.add", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.commit", + return_value=None, + ), + ): + response, status = method(api) + + assert status == 200 + assert isinstance(response, ApiToken) + assert response.token == "dataset-abc123" + assert response.type == "dataset" + + def test_post_exceed_max_keys(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + 
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + ), + ): + with pytest.raises(BadRequest) as exc_info: + method(api) + + assert exc_info.value.code == 400 + assert exc_info.value.data == { + "message": "Cannot create more than 10 API keys for this resource type.", + "custom": "max_keys_exceeded", + } + + +class TestDatasetApiDeleteApi: + def test_delete_success(self, app): + api = DatasetApiDeleteApi() + method = unwrap(api.delete) + + mock_key = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + ), + patch( + "controllers.console.datasets.datasets.db.session.commit", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.delete", + return_value=None, + ), + ): + response, status = method(api, "api-key-id") + + assert status == 204 + assert response["result"] == "success" + + def test_delete_key_not_found(self, app): + api = DatasetApiDeleteApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + ), + ): + with pytest.raises(NotFound): + method(api, "api-key-id") + + +class TestDatasetEnableApiApi: + def test_enable_api(self, app): + api = DatasetEnableApiApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", + return_value=None, + ), + ): + response, status = method(api, "dataset-1", "enable") 
+ + assert status == 200 + assert response["result"] == "success" + + def test_disable_api(self, app): + api = DatasetEnableApiApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", + return_value=None, + ), + ): + response, status = method(api, "dataset-1", "disable") + + assert status == 200 + assert response["result"] == "success" + + +class TestDatasetApiBaseUrlApi: + def test_get_api_base_url_from_config(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + + def test_get_api_base_url_from_request(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("http://localhost:5000/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + None, + ), + ): + response = method(api) + + assert response["api_base_url"] == "http://localhost:5000/v1" + + +class TestDatasetRetrievalSettingApi: + def test_get_success(self, app): + api = DatasetRetrievalSettingApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.VECTOR_STORE", + "qdrant", + ), + patch( + "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", + return_value={"retrieval_method": ["semantic", "hybrid"]}, + ), + ): + response = method(api) + + assert "retrieval_method" in response + + +class TestDatasetRetrievalSettingMockApi: + def test_get_success(self, app): + api = DatasetRetrievalSettingMockApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + 
"controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", + return_value={"retrieval_method": ["semantic"]}, + ), + ): + response = method(api, "milvus") + + assert response["retrieval_method"] == ["semantic"] + + +class TestDatasetErrorDocs: + def test_get_success(self, app): + api = DatasetErrorDocs() + method = unwrap(api.get) + + dataset = MagicMock() + error_doc = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.get_error_documents_by_dataset_id", + return_value=[error_doc], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 1 + + def test_get_dataset_not_found(self, app): + api = DatasetErrorDocs() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") + + +class TestDatasetPermissionUserListApi: + def test_get_success(self, app): + api = DatasetPermissionUserListApi() + method = unwrap(api.get) + + dataset = MagicMock() + users = [{"id": "u1"}, {"id": "u2"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetPermissionService.get_dataset_partial_member_list", + return_value=users, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["data"] == users + + def 
test_get_permission_denied(self, app): + api = DatasetPermissionUserListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "dataset-1") + + +class TestDatasetAutoDisableLogApi: + def test_get_success(self, app): + api = DatasetAutoDisableLogApi() + method = unwrap(api.get) + + dataset = MagicMock() + logs = [{"reason": "quota"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset_auto_disable_logs", + return_value=logs, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response == logs + + def test_get_dataset_not_found(self, app): + api = DatasetAutoDisableLogApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py new file mode 100644 index 0000000000..dbe54ccb99 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -0,0 +1,1379 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import 
services +from controllers.console import console_ns +from controllers.console.datasets.datasets_document import ( + DatasetDocumentListApi, + DocumentApi, + DocumentBatchDownloadZipApi, + DocumentBatchIndexingEstimateApi, + DocumentBatchIndexingStatusApi, + DocumentDownloadApi, + DocumentGenerateSummaryApi, + DocumentIndexingEstimateApi, + DocumentIndexingStatusApi, + DocumentMetadataApi, + DocumentPipelineExecutionLogApi, + DocumentProcessingApi, + DocumentRetryApi, + DocumentStatusApi, + DocumentSummaryStatusApi, + GetProcessRuleApi, +) +from controllers.console.datasets.error import ( + DocumentAlreadyFinishedError, + DocumentIndexingError, + IndexingEstimateError, + InvalidActionError, + InvalidMetadataError, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_ctx(): + return (MagicMock(is_dataset_editor=True, id="u1"), "tenant-1") + + +@pytest.fixture +def patch_tenant(tenant_ctx): + with patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=tenant_ctx, + ): + yield + + +@pytest.fixture +def dataset(): + return MagicMock(id="ds-1", indexing_technique="economy", summary_index_setting={"enable": True}) + + +@pytest.fixture +def document(): + return MagicMock( + id="doc-1", + tenant_id="tenant-1", + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + doc_form="text", + archived=False, + is_paused=False, + dataset_process_rule=None, + ) + + +@pytest.fixture +def patch_dataset(dataset): + with patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ): + yield + + +@pytest.fixture +def patch_permission(): + with patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ): + yield + + +class TestGetProcessRuleApi: + def test_get_default_success(self, 
app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + response = method(api) + + assert "rules" in response + + def test_get_with_document_dataset_not_found(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api) + + +class TestDatasetDocumentListApi: + def test_get_with_fetch_true_counts_segments(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + doc = MagicMock(id="doc-1") + pagination = MagicMock(items=[doc], total=1) + + count_mock = MagicMock(return_value=2) + + with ( + app.test_request_context("/?fetch=true"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + resp = method(api, "ds-1") + + assert resp["data"] + + def test_get_with_search_status_and_created_at_sort(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?keyword=test&status=enabled&sort=created_at"), + patch( + 
"controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.apply_display_status_filter", + side_effect=lambda q, s: q, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + resp = method(api, "ds-1") + + assert resp["total"] == 1 + + def test_get_success(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_post_success(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.post) + + payload = {"indexing_technique": "economy"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", + return_value=([MagicMock()], "batch-1"), + ), + ): + response = method(api, "ds-1") + + assert "documents" in response + + def test_post_forbidden(self, app): + api = DatasetDocumentListApi() + method = unwrap(api.post) + + 
user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/", json={}), + patch.object(type(console_ns), "payload", {}), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1") + + def test_get_with_fetch_true_and_invalid_fetch(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?fetch=maybe"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_get_sort_hit_count(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[], total=0) + + with ( + app.test_request_context("/?sort=hit_count"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 0 + + +class TestDocumentApi: + def test_get_success(self, app, patch_tenant): + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None) + + with ( + app.test_request_context("/"), + patch.object(api, 
"get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_invalid_metadata(self, app, patch_tenant): + api = DocumentApi() + method = unwrap(api.get) + + with app.test_request_context("/?metadata=wrong"), patch.object(api, "get_document", return_value=MagicMock()): + with pytest.raises(InvalidMetadataError): + method(api, "ds-1", "doc-1") + + def test_delete_success(self, app, patch_tenant, patch_dataset): + api = DocumentApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_document", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 204 + + def test_delete_indexing_error(self, app, patch_tenant, patch_dataset): + api = DocumentApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_document", + side_effect=services.errors.document.DocumentIndexingError(), + ), + ): + with pytest.raises(DocumentIndexingError): + method(api, "ds-1", "doc-1") + + +class TestDocumentDownloadApi: + def test_download_success(self, app, patch_tenant): + api = DocumentDownloadApi() + method = unwrap(api.get) + + document = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + 
"controllers.console.datasets.datasets_document.DocumentService.get_document_download_url", + return_value="url", + ), + ): + response = method(api, "ds-1", "doc-1") + + assert response["url"] == "url" + + +class TestDocumentProcessingApi: + def test_processing_forbidden_when_not_editor(self, app): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object(api, "get_document", return_value=MagicMock()), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1", "pause") + + def test_resume_from_error_state(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + doc = MagicMock(indexing_status="error", is_paused=True) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + _, status = method(api, "ds-1", "doc-1", "resume") + + assert status == 200 + + def test_resume_success(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="paused", is_paused=True) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "resume") + + assert status == 200 + + def test_pause_success(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="indexing") + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + 
"controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "pause") + + assert status == 200 + + def test_pause_invalid(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="completed") + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "pause") + + +class TestDocumentMetadataApi: + def test_put_metadata_schema_filtering(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + doc = MagicMock() + + payload = { + "doc_type": "invoice", + "doc_metadata": {"amount": 10, "invalid": "x"}, + } + + schema = {"amount": int} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"invoice": schema}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + method(api, "ds-1", "doc-1") + + assert doc.doc_metadata == {"amount": 10} + + def test_put_success(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + document = MagicMock() + + payload = {"doc_type": "others", "doc_metadata": {"a": 1}} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"others": {}}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_put_invalid_payload(self, app, patch_tenant): + api = DocumentMetadataApi() + method = 
unwrap(api.put) + + with app.test_request_context("/", json={}), patch.object(api, "get_document", return_value=MagicMock()): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + def test_put_invalid_doc_type(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + payload = {"doc_type": "invalid", "doc_metadata": {}} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"others": {}}, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + +class TestDocumentStatusApi: + def test_patch_success(self, app, patch_tenant, patch_dataset): + api = DocumentStatusApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.batch_update_document_status", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "enable") + + assert status == 200 + + def test_patch_invalid_action(self, app, patch_tenant, patch_dataset): + api = DocumentStatusApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.batch_update_document_status", + side_effect=ValueError("x"), + ), + ): + with 
pytest.raises(InvalidActionError): + method(api, "ds-1", "enable") + + +class TestDocumentRetryApi: + def test_retry_archived_document_skipped(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + doc = MagicMock(indexing_status="indexing") + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=doc, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.check_archived", + return_value=True, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + ) as retry_mock, + ): + resp, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", []) + + def test_retry_success(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + document = MagicMock(indexing_status="indexing", archived=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.check_archived", + return_value=False, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + return_value=None, + ) as retry_mock, + ): + response, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", [document]) + + def test_retry_skips_completed_document(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + document = MagicMock(indexing_status="completed", archived=False) + 
+ with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + return_value=None, + ) as retry_mock, + ): + response, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", []) + + +class TestDocumentPipelineExecutionLogApi: + def test_get_log_success(self, app, patch_tenant, patch_dataset): + api = DocumentPipelineExecutionLogApi() + method = unwrap(api.get) + + log = MagicMock( + datasource_info="{}", + datasource_type="file", + input_data={}, + datasource_node_id="n1", + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) + ), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentGenerateSummaryApi: + def test_generate_summary_missing_documents(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock( + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + payload = {"document_list": ["doc-1", "doc-2"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_documents_by_ids", + return_value=[MagicMock(id="doc-1")], + ), + ): + with pytest.raises(NotFound): + 
method(api, "ds-1") + + def test_generate_not_enabled(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock(indexing_technique="high_quality", summary_index_setting={"enable": False}) + + payload = {"document_list": ["doc-1"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1") + + def test_generate_summary_success_with_qa_skip(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock( + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + doc1 = MagicMock(id="doc-1", doc_form="qa_model") + doc2 = MagicMock(id="doc-2", doc_form="text") + + payload = {"document_list": ["doc-1", "doc-2"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_documents_by_ids", + return_value=[doc1, doc2], + ), + patch( + "controllers.console.datasets.datasets_document.generate_summary_index_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 200 + + +class TestDocumentSummaryStatusApi: + def test_get_success(self, app, patch_tenant, patch_permission): + api = DocumentSummaryStatusApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + 
"services.summary_index_service.SummaryIndexService.get_document_summary_status_detail", + return_value={"total_segments": 0}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentIndexingEstimateApi: + def test_indexing_estimate_file_not_found(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + query_mock = MagicMock() + query_mock.where.return_value.first.return_value = None + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=query_mock, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_indexing_estimate_generic_exception(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + upload_file = MagicMock() + + mock_indexing_runner = MagicMock() + mock_indexing_runner.indexing_estimate.side_effect = RuntimeError("Some indexing error") + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) + ), + ), + patch( + "controllers.console.datasets.datasets_document.ExtractSetting", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner", + 
return_value=mock_indexing_runner, + ), + ): + with pytest.raises(IndexingEstimateError): + method(api, "ds-1", "doc-1") + + def test_get_finished(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock(indexing_status="completed") + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(DocumentAlreadyFinishedError): + method(api, "ds-1", "doc-1") + + +class TestDocumentBatchDownloadZipApi: + def test_post_no_documents(self, app, patch_tenant): + api = DocumentBatchDownloadZipApi() + method = unwrap(api.post) + + payload = {"document_ids": []} + + with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError): + method(api, "ds-1") + + +class TestDatasetDocumentListApiDelete: + def test_delete_success(self, app, patch_tenant, patch_dataset): + """Test successful deletion of documents""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1&document_id=doc-2"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_documents", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 204 + + def test_delete_indexing_error(self, app, patch_tenant, patch_dataset): + """Test deletion with indexing error""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_documents", + side_effect=services.errors.document.DocumentIndexingError(), + ), + ): + 
with pytest.raises(DocumentIndexingError): + method(api, "ds-1") + + def test_delete_dataset_not_found(self, app, patch_tenant): + """Test deletion when dataset not found""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + +class TestDocumentBatchIndexingEstimateApi: + def test_batch_indexing_estimate_website(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + doc = MagicMock( + indexing_status="indexing", + data_source_type="website_crawl", + data_source_info_dict={ + "provider": "firecrawl", + "job_id": "j1", + "url": "https://x.com", + "mode": "single", + "only_main_content": True, + }, + doc_form="text", + ) + + with ( + app.test_request_context("/"), + patch.object(api, "get_batch_documents", return_value=[doc]), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 2}), + ), + ): + resp, status = method(api, "ds-1", "batch-1") + + assert status == 200 + + def test_batch_indexing_estimate_notion(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + doc = MagicMock( + indexing_status="indexing", + data_source_type="notion_import", + data_source_info_dict={ + "credential_id": "c1", + "notion_workspace_id": "w1", + "notion_page_id": "p1", + "type": "page", + }, + doc_form="text", + ) + + with ( + app.test_request_context("/"), + patch.object(api, "get_batch_documents", return_value=[doc]), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 1}), + ), + ): + resp, status = method(api, "ds-1", "batch-1") + + assert status == 200 + + def 
test_batch_estimate_unsupported_datasource(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="unknown", + data_source_info_dict={}, + doc_form="text", + ) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): + with pytest.raises(ValueError): + method(api, "ds-1", "batch-1") + + def test_get_batch_estimate_invalid_batch(self, app, patch_tenant): + """Test batch estimation with invalid batch""" + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-batch") + + +class TestDocumentBatchIndexingStatusApi: + def test_get_batch_status_invalid_batch(self, app, patch_tenant): + """Test batch status with invalid batch""" + api = DocumentBatchIndexingStatusApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-batch") + + +class TestDocumentIndexingStatusApi: + def test_get_status_document_not_found(self, app, patch_tenant): + """Test getting status for non-existent document""" + api = DocumentIndexingStatusApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_document", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-doc") + + +class TestDocumentApiMetadata: + def test_get_with_only_option(self, app, patch_tenant): + """Test get with 'only' metadata option""" + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None, doc_metadata_details=[]) + + with ( + app.test_request_context("/?metadata=only"), + patch.object(api, "get_document", return_value=document), 
+ patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_with_without_option(self, app, patch_tenant): + """Test get with 'without' metadata option""" + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None) + + with ( + app.test_request_context("/?metadata=without"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentGenerateSummaryApiSuccess: + def test_generate_not_enabled_high_quality(self, app, patch_tenant, patch_permission): + """Test summary generation on non-high-quality dataset""" + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock(indexing_technique="economy", summary_index_setting={"enable": True}) + + payload = {"document_list": ["doc-1"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1") + + +class TestDocumentProcessingApiResume: + def test_resume_invalid_status(self, app, patch_tenant): + """Test resume on non-paused document""" + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="completed", is_paused=False) + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "resume") + + +class TestDocumentPermissionCases: + def test_document_batch_get_permission_denied(self, app, patch_tenant): + api 
= DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "batch-1") + + def test_document_batch_get_documents_not_found(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch.object(api, "get_batch_documents", return_value=None), + ): + response, status = method(api, "ds-1", "batch-1") + + assert status == 200 + assert response == { + "tokens": 0, + "total_price": 0, + "currency": "USD", + "total_segments": 0, + "preview": [], + } + + def test_document_tenant_mismatch(self, app): + api = DocumentApi() + method = unwrap(api.get) + + user = MagicMock(is_dataset_editor=True) + document = MagicMock( + tenant_id="other-tenant", + dataset_process_rule=None, + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), # ✅ prevents real DB call + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + with 
pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + def test_process_rule_get_by_document_success(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + process_rule = MagicMock(mode="custom", rules_dict={"a": 1}) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + where=lambda *a: MagicMock( + order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) + ) + ), + ), + ): + result = method(api) + + if isinstance(result, tuple): + response, status = result + else: + response, status = result, 200 + + assert status == 200 + assert response["mode"] == "custom" + + def test_process_rule_permission_denied(self, app): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(MagicMock(is_dataset_editor=True), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + +class 
TestDocumentListAdvancedCases: + def test_document_list_with_multiple_sort_options(self, app, patch_tenant, patch_dataset, patch_permission): + """Test document list with different sort options""" + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?sort=updated_at"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_document_metadata_with_schema_validation(self, app, patch_tenant): + """Test document metadata update with schema validation""" + api = DocumentMetadataApi() + method = unwrap(api.put) + + doc = MagicMock() + payload = { + "doc_type": "contract", + "doc_metadata": {"amount": 5000, "currency": "USD", "invalid_field": "x"}, + } + + schema = {"amount": int, "currency": str} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"contract": schema}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert doc.doc_metadata == {"amount": 5000, "currency": "USD"} + + +class TestDocumentIndexingEdgeCases: + def test_document_indexing_with_extraction_setting(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="upload_file", + 
data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + upload_file = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_document.ExtractSetting", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 5}), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py new file mode 100644 index 0000000000..e67e4daad9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -0,0 +1,1252 @@ +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.datasets_segments import ( + ChildChunkAddApi, + ChildChunkUpdateApi, + DatasetDocumentSegmentAddApi, + DatasetDocumentSegmentApi, + DatasetDocumentSegmentBatchImportApi, + DatasetDocumentSegmentListApi, + DatasetDocumentSegmentUpdateApi, + _get_segment_with_summary, +) +from controllers.console.datasets.error import ( + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, +) +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from models.dataset import ChildChunk, DocumentSegment +from 
models.model import UploadFile + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _segment(): + return SimpleNamespace( + id="s1", + position=1, + document_id="d1", + content="c", + sign_content="c", + answer="a", + word_count=1, + tokens=1, + keywords=[], + index_node_id="n1", + index_node_hash="h", + hit_count=0, + enabled=True, + disabled_at=None, + disabled_by=None, + status="normal", + created_by="u1", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + updated_by="u1", + indexing_at=None, + completed_at=None, + error=None, + stopped_at=None, + child_chunks=[], + attachments=[], + summary=None, + ) + + +def test_get_segment_with_summary(monkeypatch): + segment = _segment() + summary = SimpleNamespace(summary_content="summary") + + monkeypatch.setattr( + "services.summary_index_service.SummaryIndexService.get_segment_summary", + lambda *_args, **_kwargs: summary, + ) + + result = _get_segment_with_summary(segment, dataset_id="d1") + + assert result["summary"] == "summary" + + +class TestDatasetDocumentSegmentListApi: + def test_get_success(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + + segment = MagicMock(spec=DocumentSegment) + segment.id = "seg-1" + + pagination = MagicMock() + pagination.items = [segment] + pagination.total = 1 + pagination.pages = 1 + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + 
"controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ), + patch( + "services.summary_index_service.SummaryIndexService.get_segments_summaries", + return_value={}, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_dataset_not_found(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_get_permission_denied(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + +class TestDatasetDocumentSegmentApi: + def test_patch_success(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.id = "doc-1" + + with ( + app.test_request_context("/?segment_id=s1&segment_id=s2"), + patch( + 
"controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.update_segments_status", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "enable") + + assert status == 200 + assert response["result"] == "success" + + def test_patch_document_indexing_in_progress(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.id = "doc-1" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=b"running", + ), + ): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", 
"disable") + + def test_patch_llm_bad_request(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock(id="doc-1") + + with ( + app.test_request_context("/?segment_id=s1"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "enable") + + def test_patch_provider_token_not_init(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock(id="doc-1") + + with ( + app.test_request_context("/?segment_id=s1"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + 
return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "enable") + + +class TestDatasetDocumentSegmentAddApi: + def test_post_success(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "hello"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.doc_form = "text" + + segment = MagicMock() + segment.id = "seg-1" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.segment_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_segment", + return_value=segment, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "seg-1"}, + ), + patch( + "controllers.console.datasets.datasets_segments._get_segment_with_summary", + 
return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert response["data"]["id"] == "seg-1" + + def test_post_llm_bad_request(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1") + + def test_post_provider_token_not_init(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + 
"controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1") + + +class TestDatasetDocumentSegmentUpdateApi: + def test_patch_success(self, app): + api = DatasetDocumentSegmentUpdateApi() + method = unwrap(api.patch) + + payload = {"content": "updated"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.doc_form = "text" + + segment = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.segment_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.update_segment", + return_value=segment, + ), + patch( + "controllers.console.datasets.datasets_segments._get_segment_with_summary", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1") + + assert status == 200 + assert "data" 
in response + + def test_patch_llm_bad_request(self, app): + api = DatasetDocumentSegmentUpdateApi() + method = unwrap(api.patch) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "seg-1") + + +class TestDatasetDocumentSegmentBatchImportApi: + def test_post_success(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock(spec=UploadFile) + upload_file.name = "test.csv" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + 
"controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.setnx", + return_value=True, + ), + patch( + "controllers.console.datasets.datasets_segments.batch_create_segment_to_index_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert response["job_status"] == "waiting" + + def test_post_dataset_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_document_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_upload_file_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = 
{"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_invalid_file_type(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock() + upload_file.name = "test.txt" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + def test_post_async_task_failure(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock() + upload_file.name = "test.csv" + + with ( + app.test_request_context("/", 
json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.setnx", + side_effect=Exception("redis down"), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 500 + assert "error" in response + + def test_get_job_not_found_in_redis(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, job_id="job-1") + + +class TestChildChunkAddApi: + def test_post_success(self, app): + api = ChildChunkAddApi() + method = unwrap(api.post) + + payload = {"content": "child"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock(spec=ChildChunk) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + 
"controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_child_chunk", + return_value=child_chunk, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "cc-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1") + + assert status == 200 + assert response["data"]["id"] == "cc-1" + + def test_post_child_chunk_indexing_error(self, app): + api = ChildChunkAddApi() + method = unwrap(api.post) + + payload = {"content": "child"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock(indexing_technique="economy") + document = MagicMock() + segment = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_child_chunk", + side_effect=services.errors.chunk.ChildChunkIndexingError("fail"), + ), 
+ ): + with pytest.raises(ChildChunkIndexingError): + method(api, "ds-1", "doc-1", "seg-1") + + +class TestChildChunkUpdateApi: + def test_delete_success(self, app): + api = ChildChunkUpdateApi() + method = unwrap(api.delete) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + side_effect=[ + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), + ], + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.delete_child_chunk", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1") + + assert status == 204 + assert response["result"] == "success" + + def test_delete_child_chunk_index_error(self, app): + api = ChildChunkUpdateApi() + method = unwrap(api.delete) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock() + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + 
"controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + side_effect=[ + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), + ], + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.delete_child_chunk", + side_effect=services.errors.chunk.ChildChunkDeleteIndexError("fail"), + ), + ): + with pytest.raises(ChildChunkDeleteIndexError): + method(api, "ds-1", "doc-1", "seg-1", "cc-1") + + +class TestSegmentListAdvancedCases: + def test_segment_list_with_keyword_filter(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + + segment = MagicMock(spec=DocumentSegment) + segment.id = "seg-1" + segment.keywords = ["test"] + segment.enabled = True + + pagination = MagicMock(items=[segment], total=1, pages=1) + + with ( + app.test_request_context("/?keyword=test"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ), + patch( + "services.summary_index_service.SummaryIndexService.get_segments_summaries", + return_value={}, + ), + ): + result = method(api, "ds-1", "doc-1") + + if 
isinstance(result, tuple): + response, status = result + else: + response, status = result, 200 + + assert status == 200 + assert response["total"] == 1 + + def test_segment_list_permission_denied(self, app): + """Test segment list with permission denied""" + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + def test_segment_list_dataset_not_found(self, app): + """Test segment list with dataset not found""" + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + +class TestSegmentOperationCases: + def test_segment_add_with_provider_token_error(self, app): + """Test segment add with provider token not initialized""" + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + + payload = {"content": "new content", "answer": None} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + 
return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_segment", + side_effect=ProviderTokenNotInitError("Token not init"), + ), + ): + with pytest.raises(ProviderTokenNotInitError): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_document_not_found(self, app): + """Test batch import with document not found""" + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_invalid_file(self, app): + """Test batch import with invalid file type""" + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + upload_file = None # File not found + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + 
"controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_async_task_failure(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + upload_file = MagicMock(spec=UploadFile, extension="csv", id="file-1") + upload_file.name = "test.csv" + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.batch_create_segment_to_index_task.delay", + side_effect=Exception("Task failed"), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 500 + assert "error" in 
response + + def test_batch_import_get_job_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.get) + + user = MagicMock(is_dataset_editor=True) + + with ( + app.test_request_context("/?job_id=invalid-job"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, "invalid-job") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py new file mode 100644 index 0000000000..161d0c41e8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -0,0 +1,399 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.datasets.external import ( + BedrockRetrievalApi, + ExternalApiTemplateApi, + ExternalApiTemplateListApi, + ExternalDatasetCreateApi, + ExternalKnowledgeHitTestingApi, +) +from services.dataset_service import DatasetService +from services.external_knowledge_service import ExternalDatasetService +from services.hit_testing_service import HitTestingService +from services.knowledge_service import ExternalDatasetTestService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_external_dataset") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def current_user(): + user = MagicMock() + user.id = "user-1" + user.is_dataset_editor = True + user.has_edit_permission = True + user.is_dataset_operator = True + 
return user + + +@pytest.fixture(autouse=True) +def mock_auth(mocker, current_user): + mocker.patch( + "controllers.console.datasets.external.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ) + + +class TestExternalApiTemplateListApi: + def test_get_success(self, app): + api = ExternalApiTemplateListApi() + method = unwrap(api.get) + + api_item = MagicMock() + api_item.to_dict.return_value = {"id": "1"} + + with ( + app.test_request_context("/?page=1&limit=20"), + patch.object( + ExternalDatasetService, + "get_external_knowledge_apis", + return_value=([api_item], 1), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 1 + assert resp["data"][0]["id"] == "1" + + def test_post_forbidden(self, app, current_user): + current_user.is_dataset_editor = False + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "x", "settings": {"k": "v"}} + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list"), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_duplicate_name(self, app): + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "x", "settings": {"k": "v"}} + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list"), + patch.object( + ExternalDatasetService, + "create_external_knowledge_api", + side_effect=services.errors.dataset.DatasetNameDuplicateError(), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + +class TestExternalApiTemplateApi: + def test_get_not_found(self, app): + api = ExternalApiTemplateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + 
"get_external_knowledge_api", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "api-id") + + def test_delete_forbidden(self, app, current_user): + current_user.has_edit_permission = False + current_user.is_dataset_operator = False + + api = ExternalApiTemplateApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "api-id") + + +class TestExternalDatasetCreateApi: + def test_create_success(self, app): + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + payload = { + "external_knowledge_api_id": "api", + "external_knowledge_id": "kid", + "name": "dataset", + } + + dataset = MagicMock() + + dataset.embedding_available = False + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.enable_qa = False + dataset.enable_vector_store = False + dataset.vector_store_setting = None + dataset.is_multimodal = False + + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object( + ExternalDatasetService, + "create_external_dataset", + return_value=dataset, + ), + ): + _, status = method(api) + + assert status == 201 + + def test_create_forbidden(self, app, current_user): + current_user.is_dataset_editor = False + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + payload = { + "external_knowledge_api_id": "api", + "external_knowledge_id": "kid", + "name": "dataset", + } + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + ): + with 
pytest.raises(Forbidden): + method(api) + + +class TestExternalKnowledgeHitTestingApi: + def test_hit_testing_dataset_not_found(self, app): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-id") + + def test_hit_testing_success(self, app): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + payload = {"query": "hello"} + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission"), + patch.object( + HitTestingService, + "external_retrieve", + return_value={"ok": True}, + ), + ): + resp = method(api, "dataset-id") + + assert resp["ok"] is True + + +class TestBedrockRetrievalApi: + def test_bedrock_retrieval(self, app): + api = BedrockRetrievalApi() + method = unwrap(api.post) + + payload = { + "retrieval_setting": {}, + "query": "hello", + "knowledge_id": "kid", + } + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object( + ExternalDatasetTestService, + "knowledge_retrieval", + return_value={"ok": True}, + ), + ): + resp, status = method() + + assert status == 200 + assert resp["ok"] is True + + +class TestExternalApiTemplateListApiAdvanced: + def test_post_duplicate_name_error(self, app, mock_auth, current_user): + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "duplicate_api", "settings": {"key": "value"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + 
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"), + patch( + "controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api", + side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_get_with_pagination(self, app, mock_auth, current_user): + api = ExternalApiTemplateListApi() + method = unwrap(api.get) + + templates = [MagicMock(id=f"api-{i}") for i in range(3)] + + with ( + app.test_request_context("/?page=1&limit=20"), + patch( + "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis", + return_value=(templates, 25), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 25 + assert len(resp["data"]) == 3 + + +class TestExternalDatasetCreateApiAdvanced: + def test_create_forbidden(self, app, mock_auth, current_user): + """Test creating external dataset without permission""" + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + current_user.is_dataset_editor = False + + payload = { + "external_knowledge_api_id": "api-1", + "external_knowledge_id": "ek-1", + "name": "new_dataset", + "description": "A dataset", + } + + with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(Forbidden): + method(api) + + +class TestExternalKnowledgeHitTestingApiAdvanced: + def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user): + """Test hit testing on non-existent dataset""" + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "test query", + "external_retrieval_model": None, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.DatasetService.get_dataset", + return_value=None, + ), + ): + 
with pytest.raises(NotFound): + method(api, "ds-1") + + def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + dataset = MagicMock() + payload = { + "query": "test query", + "external_retrieval_model": {"type": "bm25"}, + "metadata_filtering_conditions": {"status": "active"}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"), + patch( + "controllers.console.datasets.external.HitTestingService.external_retrieve", + return_value={"results": []}, + ), + ): + resp = method(api, "ds-1") + + assert resp["results"] == [] + + +class TestBedrockRetrievalApiAdvanced: + def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user): + api = BedrockRetrievalApi() + method = unwrap(api.post) + + payload = { + "retrieval_setting": {}, + "query": "test", + "knowledge_id": "k-1", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval", + side_effect=ValueError("Invalid settings"), + ), + ): + with pytest.raises(ValueError): + method() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py new file mode 100644 index 0000000000..55fb038156 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -0,0 +1,160 @@ +import uuid +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.console import console_ns +from 
controllers.console.datasets.hit_testing import HitTestingApi +from controllers.console.datasets.hit_testing_base import HitTestingPayload + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_hit_testing") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def dataset_id(): + return uuid.uuid4() + + +@pytest.fixture +def dataset(): + return MagicMock(id="dataset-1") + + +@pytest.fixture(autouse=True) +def bypass_decorators(mocker): + """Bypass all decorators on the API method.""" + mocker.patch( + "controllers.console.datasets.hit_testing.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.login_required", + return_value=lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.account_initialization_required", + return_value=lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check", + return_value=lambda *_: (lambda f: f), + ) + + +class TestHitTestingApi: + def test_hit_testing_success(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + "top_k": 3, + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": "what is vector search", "records": []}, + ), + ): + result = method(api, dataset_id) + + assert "query" in result + assert "records" in result + 
assert result["records"] == [] + + def test_hit_testing_dataset_not_found(self, app, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "test", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + side_effect=NotFound("Dataset not found"), + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_hit_testing_invalid_args(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + side_effect=ValueError("Invalid parameters"), + ), + ): + with pytest.raises(ValueError, match="Invalid parameters"): + method(api, dataset_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py new file mode 100644 index 0000000000..e7ae37ae45 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -0,0 +1,207 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import 
DatasetNotInitializedError +from controllers.console.datasets.hit_testing_base import ( + DatasetsHitTestingBase, +) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from models.account import Account +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + + +@pytest.fixture +def account(): + acc = MagicMock(spec=Account) + return acc + + +@pytest.fixture(autouse=True) +def patch_current_user(mocker, account): + """Patch current_user to a valid Account.""" + mocker.patch( + "controllers.console.datasets.hit_testing_base.current_user", + account, + ) + + +@pytest.fixture +def dataset(): + return MagicMock(id="dataset-1") + + +class TestGetAndValidateDataset: + def test_success(self, dataset): + with ( + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + ): + result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + assert result == dataset + + def test_dataset_not_found(self): + with patch.object( + DatasetService, + "get_dataset", + return_value=None, + ): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + def test_permission_denied(self, dataset): + with ( + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden, match="no access"): + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + +class TestHitTestingArgsCheck: + def test_args_check_called(self): + args = {"query": "test"} + + with patch.object( + HitTestingService, + "hit_testing_args_check", + ) as 
check_mock: + DatasetsHitTestingBase.hit_testing_args_check(args) + + check_mock.assert_called_once_with(args) + + +class TestParseArgs: + def test_parse_args_success(self): + payload = {"query": "hello"} + + result = DatasetsHitTestingBase.parse_args(payload) + + assert result["query"] == "hello" + + def test_parse_args_invalid(self): + payload = {"query": "x" * 300} + + with pytest.raises(ValueError): + DatasetsHitTestingBase.parse_args(payload) + + +class TestPerformHitTesting: + def test_success(self, dataset): + response = { + "query": "hello", + "records": [], + } + + with patch.object( + HitTestingService, + "retrieve", + return_value=response, + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == "hello" + assert result["records"] == [] + + def test_index_not_initialized(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=services.errors.index.IndexNotInitializedError(), + ): + with pytest.raises(DatasetNotInitializedError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_provider_token_not_init(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ProviderTokenNotInitError("token missing"), + ): + with pytest.raises(ProviderNotInitializeError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_quota_exceeded(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=QuotaExceededError(), + ): + with pytest.raises(ProviderQuotaExceededError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_model_not_supported(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ModelCurrentlyNotSupportError(), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def 
test_llm_bad_request(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=LLMBadRequestError("bad request"), + ): + with pytest.raises(ProviderNotInitializeError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_invoke_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=InvokeError("invoke failed"), + ): + with pytest.raises(CompletionRequestError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_value_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ValueError("bad args"), + ): + with pytest.raises(ValueError, match="bad args"): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_unexpected_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=Exception("boom"), + ): + with pytest.raises(InternalServerError, match="boom"): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py new file mode 100644 index 0000000000..de834c2d4d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -0,0 +1,362 @@ +import uuid +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.console import console_ns +from controllers.console.datasets.metadata import ( + DatasetMetadataApi, + DatasetMetadataBuiltInFieldActionApi, + DatasetMetadataBuiltInFieldApi, + DatasetMetadataCreateApi, + DocumentMetadataEditApi, +) +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import ( + MetadataArgs, + MetadataOperationData, +) +from services.metadata_service import 
MetadataService + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_dataset_metadata") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def current_user(): + user = MagicMock() + user.id = "user-1" + return user + + +@pytest.fixture +def dataset(): + ds = MagicMock() + ds.id = "dataset-1" + return ds + + +@pytest.fixture +def dataset_id(): + return uuid.uuid4() + + +@pytest.fixture +def metadata_id(): + return uuid.uuid4() + + +@pytest.fixture(autouse=True) +def bypass_decorators(mocker): + """Bypass setup/login/license decorators.""" + mocker.patch( + "controllers.console.datasets.metadata.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.login_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.account_initialization_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.enterprise_license_required", + lambda f: f, + ) + + +class TestDatasetMetadataCreateApi: + def test_create_metadata_success(self, app, current_user, dataset, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.post) + + payload = {"name": "author"} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + MetadataArgs, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "create_metadata", + return_value={"id": "m1", "name": "author"}, + ), + ): + result, status = method(api, dataset_id) + 
+ assert status == 201 + assert result["name"] == "author" + + def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.post) + + valid_payload = { + "type": "string", + "name": "author", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=valid_payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + MetadataArgs, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + +class TestDatasetMetadataGetApi: + def test_get_metadata_success(self, app, dataset, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + MetadataService, + "get_dataset_metadatas", + return_value=[{"id": "m1"}], + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert isinstance(result, list) + + def test_get_metadata_dataset_not_found(self, app, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, dataset_id) + + +class TestDatasetMetadataApi: + def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id): + api = DatasetMetadataApi() + method = unwrap(api.patch) + + payload = {"name": "updated-name"} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + 
patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "update_metadata_name", + return_value={"id": "m1", "name": "updated-name"}, + ), + ): + result, status = method(api, dataset_id, metadata_id) + + assert status == 200 + assert result["name"] == "updated-name" + + def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id): + api = DatasetMetadataApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "delete_metadata", + ), + ): + result, status = method(api, dataset_id, metadata_id) + + assert status == 204 + assert result["result"] == "success" + + +class TestDatasetMetadataBuiltInFieldApi: + def test_get_built_in_fields(self, app): + api = DatasetMetadataBuiltInFieldApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + MetadataService, + "get_built_in_fields", + return_value=["title", "source"], + ), + ): + result, status = method(api) + + assert status == 200 + assert result["fields"] == ["title", "source"] + + +class TestDatasetMetadataBuiltInFieldActionApi: + def test_enable_built_in_field(self, app, current_user, dataset, dataset_id): + api = DatasetMetadataBuiltInFieldActionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + 
DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "enable_built_in_field", + ), + ): + result, status = method(api, dataset_id, "enable") + + assert status == 200 + assert result["result"] == "success" + + +class TestDocumentMetadataEditApi: + def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id): + api = DocumentMetadataEditApi() + method = unwrap(api.post) + + payload = {"operation": "add", "metadata": {}} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataOperationData, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + MetadataService, + "update_documents_metadata", + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/datasets/test_website.py b/api/tests/unit_tests/controllers/console/datasets/test_website.py new file mode 100644 index 0000000000..9f0da6e76f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_website.py @@ -0,0 +1,233 @@ +from unittest.mock import Mock, PropertyMock, patch + +import pytest +from flask import Flask + +from controllers.console import console_ns +from controllers.console.datasets.error import WebsiteCrawlError +from controllers.console.datasets.website import ( + WebsiteCrawlApi, + WebsiteCrawlStatusApi, +) +from services.website_service import ( + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + 
+def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_website_crawl") + app.config["TESTING"] = True + return app + + +@pytest.fixture(autouse=True) +def bypass_auth_and_setup(mocker): + """Bypass setup/login/account decorators.""" + mocker.patch( + "controllers.console.datasets.website.login_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.website.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.website.account_initialization_required", + lambda f: f, + ) + + +class TestWebsiteCrawlApi: + def test_crawl_success(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "https://example.com", + "options": {"depth": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mock_request = Mock(spec=WebsiteCrawlApiRequest) + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "crawl_url", + return_value={"job_id": "job-1"}, + ) + + result, status = method(api) + + assert status == 200 + assert result["job_id"] == "job-1" + + def test_crawl_invalid_payload(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "bad-url", + "options": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + side_effect=ValueError("invalid payload"), + ) + + with pytest.raises(WebsiteCrawlError, match="invalid payload"): + method(api) + + def 
test_crawl_service_error(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "https://example.com", + "options": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mock_request = Mock(spec=WebsiteCrawlApiRequest) + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "crawl_url", + side_effect=Exception("crawl failed"), + ) + + with pytest.raises(WebsiteCrawlError, match="crawl failed"): + method(api) + + +class TestWebsiteCrawlStatusApi: + def test_get_status_success(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mock_request = Mock(spec=WebsiteCrawlStatusApiRequest) + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "get_crawl_status_typed", + return_value={"status": "completed"}, + ) + + result, status = method(api, job_id) + + assert status == 200 + assert result["status"] == "completed" + + def test_get_status_invalid_provider(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + side_effect=ValueError("invalid provider"), + ) + + with pytest.raises(WebsiteCrawlError, match="invalid provider"): + method(api, 
job_id) + + def test_get_status_service_error(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mock_request = Mock(spec=WebsiteCrawlStatusApiRequest) + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "get_crawl_status_typed", + side_effect=Exception("status lookup failed"), + ) + + with pytest.raises(WebsiteCrawlError, match="status lookup failed"): + method(api, job_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py new file mode 100644 index 0000000000..90f00711c1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -0,0 +1,117 @@ +from unittest.mock import Mock + +import pytest + +from controllers.console.datasets.error import PipelineNotFoundError +from controllers.console.datasets.wraps import get_rag_pipeline +from models.dataset import Pipeline + + +class TestGetRagPipeline: + def test_missing_pipeline_id(self): + @get_rag_pipeline + def dummy_view(**kwargs): + return "ok" + + with pytest.raises(ValueError, match="missing pipeline_id"): + dummy_view() + + def test_pipeline_not_found(self, mocker): + @get_rag_pipeline + def dummy_view(**kwargs): + return "ok" + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = None + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + with pytest.raises(PipelineNotFoundError): + dummy_view(pipeline_id="pipeline-1") + + def 
test_pipeline_found_and_injected(self, mocker): + pipeline = Mock(spec=Pipeline) + pipeline.id = "pipeline-1" + pipeline.tenant_id = "tenant-1" + + @get_rag_pipeline + def dummy_view(**kwargs): + return kwargs["pipeline"] + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = pipeline + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id="pipeline-1") + + assert result is pipeline + + def test_pipeline_id_removed_from_kwargs(self, mocker): + pipeline = Mock(spec=Pipeline) + + @get_rag_pipeline + def dummy_view(**kwargs): + assert "pipeline_id" not in kwargs + return "ok" + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = pipeline + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id="pipeline-1") + + assert result == "ok" + + def test_pipeline_id_cast_to_string(self, mocker): + pipeline = Mock(spec=Pipeline) + + @get_rag_pipeline + def dummy_view(**kwargs): + return kwargs["pipeline"] + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + def where_side_effect(*args, **kwargs): + assert args[0].right.value == "123" + return Mock(first=lambda: pipeline) + + mock_query = Mock() + mock_query.where.side_effect = where_side_effect + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id=123) + + assert result is pipeline diff --git a/api/tests/unit_tests/controllers/console/explore/__init__.py 
b/api/tests/unit_tests/controllers/console/explore/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py new file mode 100644 index 0000000000..0afbc5a8f7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -0,0 +1,402 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import InternalServerError + +import controllers.console.explore.audio as audio_module +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, +) + + +def unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +@pytest.fixture +def installed_app(): + app = MagicMock() + app.app = MagicMock() + return app + + +@pytest.fixture +def audio_file(): + return (BytesIO(b"audio"), "audio.wav") + + +class TestChatAudioApi: + def setup_method(self): + self.api = audio_module.ChatAudioApi() + self.method = unwrap(self.api.post) + + def test_post_success(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + return_value={"text": "ok"}, + ), + ): + resp = 
self.method(installed_app) + + assert resp == {"text": "ok"} + + def test_app_unavailable(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + self.method(installed_app) + + def test_no_audio_uploaded(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(NoAudioUploadedError): + self.method(installed_app) + + def test_audio_too_large(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=AudioTooLargeServiceError("too big"), + ), + ): + with pytest.raises(AudioTooLargeError): + self.method(installed_app) + + def test_provider_quota_exceeded(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + self.method(installed_app) + + def test_unknown_exception(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + self.method(installed_app) + + def 
test_unsupported_audio_type(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=audio_module.UnsupportedAudioTypeServiceError(), + ), + ): + with pytest.raises(audio_module.UnsupportedAudioTypeError): + self.method(installed_app) + + def test_provider_not_support_speech_to_text(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError): + self.method(installed_app) + + def test_provider_not_initialized(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + self.method(installed_app) + + def test_model_currently_not_supported(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + self.method(installed_app) + + def test_invoke_error_asr(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=InvokeError("invoke failed"), + ), + ): + with 
pytest.raises(CompletionRequestError): + self.method(installed_app) + + +class TestChatTextApi: + def setup_method(self): + self.api = audio_module.ChatTextApi() + self.method = unwrap(self.api.post) + + def test_post_success(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"message_id": "m1", "text": "hello", "voice": "v1"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + return_value={"audio": "ok"}, + ), + ): + resp = self.method(installed_app) + + assert resp == {"audio": "ok"} + + def test_provider_not_initialized(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + self.method(installed_app) + + def test_model_not_supported(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + self.method(installed_app) + + def test_invoke_error(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=InvokeError("invoke failed"), + ), + ): + with pytest.raises(CompletionRequestError): + self.method(installed_app) + + def test_unknown_exception(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + self.method(installed_app) + + def test_app_unavailable_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + 
patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + self.method(installed_app) + + def test_no_audio_uploaded_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(NoAudioUploadedError): + self.method(installed_app) + + def test_audio_too_large_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=AudioTooLargeServiceError("too big"), + ), + ): + with pytest.raises(AudioTooLargeError): + self.method(installed_app) + + def test_unsupported_audio_type_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=audio_module.UnsupportedAudioTypeServiceError(), + ), + ): + with pytest.raises(audio_module.UnsupportedAudioTypeError): + self.method(installed_app) + + def test_provider_not_support_speech_to_text_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError): + self.method(installed_app) + + def test_quota_exceeded_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + self.method(installed_app) diff --git 
a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py new file mode 100644 index 0000000000..0606219356 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -0,0 +1,100 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import controllers.console.explore.banner as banner_module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestBannerApi: + def test_get_banners_with_requested_language(self, app): + api = banner_module.BannerApi() + method = unwrap(api.get) + + banner = MagicMock() + banner.id = "b1" + banner.content = {"text": "hello"} + banner.link = "https://example.com" + banner.sort = 1 + banner.status = "enabled" + banner.created_at = datetime(2024, 1, 1) + + query = MagicMock() + query.where.return_value = query + query.order_by.return_value = query + query.all.return_value = [banner] + + session = MagicMock() + session.query.return_value = query + + with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session): + result = method(api) + + assert result == [ + { + "id": "b1", + "content": {"text": "hello"}, + "link": "https://example.com", + "sort": 1, + "status": "enabled", + "created_at": "2024-01-01T00:00:00", + } + ] + + def test_get_banners_fallback_to_en_us(self, app): + api = banner_module.BannerApi() + method = unwrap(api.get) + + banner = MagicMock() + banner.id = "b2" + banner.content = {"text": "fallback"} + banner.link = None + banner.sort = 1 + banner.status = "enabled" + banner.created_at = None + + query = MagicMock() + query.where.return_value = query + query.order_by.return_value = query + query.all.side_effect = [ + [], + [banner], + ] + + session = MagicMock() + session.query.return_value = query + + with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", 
session): + result = method(api) + + assert result == [ + { + "id": "b2", + "content": {"text": "fallback"}, + "link": None, + "sort": 1, + "status": "enabled", + "created_at": None, + } + ] + + def test_get_banners_default_language_en_us(self, app): + api = banner_module.BannerApi() + method = unwrap(api.get) + + query = MagicMock() + query.where.return_value = query + query.order_by.return_value = query + query.all.return_value = [] + + session = MagicMock() + session.query.return_value = query + + with app.test_request_context("/"), patch.object(banner_module.db, "session", session): + result = method(api) + + assert result == [] diff --git a/api/tests/unit_tests/controllers/console/explore/test_completion.py b/api/tests/unit_tests/controllers/console/explore/test_completion.py new file mode 100644 index 0000000000..1dd16f3c59 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_completion.py @@ -0,0 +1,459 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import InternalServerError + +import controllers.console.explore.completion as completion_module +from controllers.console.app.error import ( + ConversationCompletedError, +) +from controllers.console.explore.error import NotChatAppError, NotCompletionAppError +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from models import Account +from models.model import AppMode +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user(): + return MagicMock(spec=Account) + + +@pytest.fixture +def completion_app(): + return MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) + + +@pytest.fixture +def chat_app(): + return MagicMock(app=MagicMock(mode=AppMode.CHAT)) + + +@pytest.fixture +def payload_data(): + return {"inputs": {}, "query": "hi"} + + +@pytest.fixture +def 
payload_patch(payload_data): + return patch.object( + type(completion_module.console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload_data, + ) + + +class TestCompletionApi: + def test_post_success(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + return_value={"ok": True}, + ), + patch.object( + completion_module.helper, + "compact_generate_response", + return_value=("ok", 200), + ), + ): + result = method(completion_app) + + assert result == ("ok", 200) + + def test_post_wrong_app_mode(self): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT)) + + with pytest.raises(NotCompletionAppError): + method(installed_app) + + def test_conversation_completed(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationCompletedError(), + ), + ): + with pytest.raises(ConversationCompletedError): + method(completion_app) + + def test_internal_error(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + method(completion_app) + + def 
test_conversation_not_exists(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(completion_module.NotFound): + method(completion_app) + + def test_app_unavailable(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(completion_module.AppUnavailableError): + method(completion_app) + + def test_provider_not_initialized(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(completion_module.ProviderNotInitializeError): + method(completion_app) + + def test_quota_exceeded(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.QuotaExceededError(), + ), + ): + with 
pytest.raises(completion_module.ProviderQuotaExceededError): + method(completion_app) + + def test_model_not_supported(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError): + method(completion_app) + + def test_invoke_error(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.InvokeError("invoke failed"), + ), + ): + with pytest.raises(completion_module.CompletionRequestError): + method(completion_app) + + +class TestCompletionStopApi: + def test_stop_success(self, completion_app, user): + api = completion_module.CompletionStopApi() + method = unwrap(api.post) + + user.id = "u1" + + with ( + patch.object(completion_module, "current_user", user), + patch.object(completion_module.AppTaskService, "stop_task"), + ): + resp, status = method(completion_app, "task-1") + + assert status == 200 + assert resp == {"result": "success"} + + def test_stop_wrong_app_mode(self): + api = completion_module.CompletionStopApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT)) + + with pytest.raises(NotCompletionAppError): + method(installed_app, "task") + + +class TestChatApi: + def test_post_success(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + 
app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + return_value={"ok": True}, + ), + patch.object( + completion_module.helper, + "compact_generate_response", + return_value=("ok", 200), + ), + ): + result = method(chat_app) + + assert result == ("ok", 200) + + def test_post_not_chat_app(self): + api = completion_module.ChatApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) + + with pytest.raises(NotChatAppError): + method(installed_app) + + def test_rate_limit_error(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(chat_app) + + def test_conversation_completed_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationCompletedError(), + ), + ): + with pytest.raises(ConversationCompletedError): + method(chat_app) + + def test_conversation_not_exists_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + 
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(completion_module.NotFound): + method(chat_app) + + def test_app_unavailable_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(completion_module.AppUnavailableError): + method(chat_app) + + def test_provider_not_initialized_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(completion_module.ProviderNotInitializeError): + method(chat_app) + + def test_quota_exceeded_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.QuotaExceededError(), + ), + ): + with pytest.raises(completion_module.ProviderQuotaExceededError): + method(chat_app) + + def test_model_not_supported_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + 
completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError): + method(chat_app) + + def test_invoke_error_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.InvokeError("invoke failed"), + ), + ): + with pytest.raises(completion_module.CompletionRequestError): + method(chat_app) + + def test_internal_error_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + method(chat_app) + + +class TestChatStopApi: + def test_stop_success(self, chat_app, user): + api = completion_module.ChatStopApi() + method = unwrap(api.post) + + user.id = "u1" + + with ( + patch.object(completion_module, "current_user", user), + patch.object(completion_module.AppTaskService, "stop_task"), + ): + resp, status = method(chat_app, "task-1") + + assert status == 200 + assert resp == {"result": "success"} + + def test_stop_not_chat_app(self): + api = completion_module.ChatStopApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) + + with pytest.raises(NotChatAppError): + method(installed_app, "task") diff --git a/api/tests/unit_tests/controllers/console/explore/test_conversation.py b/api/tests/unit_tests/controllers/console/explore/test_conversation.py new file mode 100644 index 
0000000000..65cc209725 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_conversation.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +import controllers.console.explore.conversation as conversation_module +from controllers.console.explore.error import NotChatAppError +from models import Account +from models.model import AppMode +from services.errors.conversation import ( + ConversationNotExistsError, + LastConversationNotExistsError, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class FakeConversation: + def __init__(self, cid): + self.id = cid + self.name = "test" + self.inputs = {} + self.status = "normal" + self.introduction = "" + + +@pytest.fixture +def chat_app(): + app_model = MagicMock(mode=AppMode.CHAT, id="app-id") + return MagicMock(app=app_model) + + +@pytest.fixture +def non_chat_app(): + app_model = MagicMock(mode=AppMode.COMPLETION) + return MagicMock(app=app_model) + + +@pytest.fixture +def user(): + user = MagicMock(spec=Account) + user.id = "uid" + return user + + +@pytest.fixture(autouse=True) +def mock_db_and_session(): + with ( + patch.object( + conversation_module, + "db", + MagicMock(session=MagicMock(), engine=MagicMock()), + ), + patch( + "controllers.console.explore.conversation.Session", + MagicMock(), + ), + ): + yield + + +class TestConversationListApi: + def test_get_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationListApi() + method = unwrap(api.get) + + pagination = MagicMock( + limit=20, + has_more=False, + data=[FakeConversation("c1"), FakeConversation("c2")], + ) + + with ( + app.test_request_context("/?limit=20"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "pagination_by_last_id", + return_value=pagination, + ), + ): + result = 
method(chat_app) + + assert result["limit"] == 20 + assert result["has_more"] is False + assert len(result["data"]) == 2 + + def test_last_conversation_not_exists(self, app: Flask, chat_app, user): + api = conversation_module.ConversationListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "pagination_by_last_id", + side_effect=LastConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(chat_app) + + def test_wrong_app_mode(self, app: Flask, non_chat_app): + api = conversation_module.ConversationListApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(NotChatAppError): + method(non_chat_app) + + +class TestConversationApi: + def test_delete_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "delete", + ), + ): + result = method(chat_app, "cid") + + body, status = result + assert status == 204 + assert body["result"] == "success" + + def test_delete_not_found(self, app: Flask, chat_app, user): + api = conversation_module.ConversationApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "delete", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(chat_app, "cid") + + def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): + api = conversation_module.ConversationApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with pytest.raises(NotChatAppError): + method(non_chat_app, "cid") + + +class 
TestConversationRenameApi: + def test_rename_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationRenameApi() + method = unwrap(api.post) + + conversation = FakeConversation("cid") + + with ( + app.test_request_context("/", json={"name": "new"}), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "rename", + return_value=conversation, + ), + ): + result = method(chat_app, "cid") + + assert result["id"] == "cid" + + def test_rename_not_found(self, app: Flask, chat_app, user): + api = conversation_module.ConversationRenameApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "new"}), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "rename", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(chat_app, "cid") + + +class TestConversationPinApi: + def test_pin_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationPinApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "pin", + ), + ): + result = method(chat_app, "cid") + + assert result == {"result": "success"} + + +class TestConversationUnPinApi: + def test_unpin_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationUnPinApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "unpin", + ), + ): + result = method(chat_app, "cid") + + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py 
new file mode 100644 index 0000000000..3983a6a97e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -0,0 +1,363 @@ +from datetime import datetime +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +import controllers.console.explore.installed_app as module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_id(): + return "t1" + + +@pytest.fixture +def current_user(tenant_id): + user = MagicMock() + user.id = "u1" + user.current_tenant = MagicMock(id=tenant_id) + return user + + +@pytest.fixture +def installed_app(): + app = MagicMock() + app.id = "ia1" + app.app = MagicMock(id="a1") + app.app_owner_tenant_id = "t2" + app.is_pinned = False + app.last_used_at = datetime(2024, 1, 1) + return app + + +@pytest.fixture +def payload_patch(): + def _patch(payload): + return patch.object( + type(module.console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return _patch + + +class TestInstalledAppsListApi: + def test_get_installed_apps(self, app, current_user, tenant_id, installed_app): + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), + ), + ): + result = method(api) + + assert "installed_apps" in result + assert result["installed_apps"][0]["editable"] is True + assert result["installed_apps"][0]["uninstallable"] is False + + def 
test_get_installed_apps_with_app_id_filter(self, app, current_user, tenant_id): + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [] + + with ( + app.test_request_context("/?app_id=a1"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="member"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), + ), + ): + result = method(api) + + assert result == {"installed_apps": []} + + def test_get_installed_apps_with_webapp_auth_enabled(self, app, current_user, tenant_id, installed_app): + """Test filtering when webapp_auth is enabled.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + mock_webapp_setting = MagicMock() + mock_webapp_setting.access_mode = "restricted" + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=True)), + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_get_app_access_mode_by_id", + return_value={"a1": mock_webapp_setting}, + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_is_user_allowed_to_access_webapps", + return_value={"a1": True}, + ), + ): + result = method(api) + + assert len(result["installed_apps"]) == 1 + + def test_get_installed_apps_with_webapp_auth_user_denied(self, app, current_user, tenant_id, installed_app): + """Test filtering when user 
doesn't have access.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + mock_webapp_setting = MagicMock() + mock_webapp_setting.access_mode = "restricted" + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="member"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=True)), + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_get_app_access_mode_by_id", + return_value={"a1": mock_webapp_setting}, + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_is_user_allowed_to_access_webapps", + return_value={"a1": False}, + ), + ): + result = method(api) + + assert result["installed_apps"] == [] + + def test_get_installed_apps_with_sso_verified_access(self, app, current_user, tenant_id, installed_app): + """Test that sso_verified access mode apps are skipped in filtering.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + mock_webapp_setting = MagicMock() + mock_webapp_setting.access_mode = "sso_verified" + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=True)), + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_get_app_access_mode_by_id", + return_value={"a1": mock_webapp_setting}, + ), + ): + result = 
method(api) + + assert len(result["installed_apps"]) == 0 + + def test_get_installed_apps_filters_null_apps(self, app, current_user, tenant_id): + """Test that installed apps with null app are filtered out.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + installed_app_with_null = MagicMock() + installed_app_with_null.app = None + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app_with_null] + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), + ), + ): + result = method(api) + + assert result["installed_apps"] == [] + + def test_get_installed_apps_current_tenant_none(self, app, tenant_id, installed_app): + """Test error when current_user.current_tenant is None.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + current_user = MagicMock() + current_user.current_tenant = None + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + ): + with pytest.raises(ValueError, match="current_user.current_tenant must not be None"): + method(api) + + +class TestInstalledAppsCreateApi: + def test_post_success(self, app, tenant_id, payload_patch): + api = module.InstalledAppsListApi() + method = unwrap(api.post) + + recommended = MagicMock() + recommended.install_count = 0 + + app_entity = MagicMock() + app_entity.id = "a1" + app_entity.is_public = True + app_entity.tenant_id = "t2" + + session = MagicMock() + 
session.query.return_value.where.return_value.first.side_effect = [ + recommended, + app_entity, + None, + ] + + with ( + app.test_request_context("/", json={"app_id": "a1"}), + payload_patch({"app_id": "a1"}), + patch.object(module.db, "session", session), + patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), + ): + result = method(api) + + assert result == {"message": "App installed successfully"} + assert recommended.install_count == 1 + + def test_post_recommended_not_found(self, app, payload_patch): + api = module.InstalledAppsListApi() + method = unwrap(api.post) + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + + with ( + app.test_request_context("/", json={"app_id": "a1"}), + payload_patch({"app_id": "a1"}), + patch.object(module.db, "session", session), + ): + with pytest.raises(NotFound): + method(api) + + def test_post_app_not_public(self, app, tenant_id, payload_patch): + api = module.InstalledAppsListApi() + method = unwrap(api.post) + + recommended = MagicMock() + app_entity = MagicMock(is_public=False) + + session = MagicMock() + session.query.return_value.where.return_value.first.side_effect = [ + recommended, + app_entity, + ] + + with ( + app.test_request_context("/", json={"app_id": "a1"}), + payload_patch({"app_id": "a1"}), + patch.object(module.db, "session", session), + patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestInstalledAppApi: + def test_delete_success(self, tenant_id, installed_app): + api = module.InstalledAppApi() + method = unwrap(api.delete) + + with ( + patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), + patch.object(module.db, "session"), + ): + resp, status = method(installed_app) + + assert status == 204 + assert resp["result"] == "success" + + def test_delete_owned_by_current_tenant(self, tenant_id): + api 
= module.InstalledAppApi() + method = unwrap(api.delete) + + installed_app = MagicMock(app_owner_tenant_id=tenant_id) + + with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)): + with pytest.raises(BadRequest): + method(installed_app) + + def test_patch_update_pin(self, app, payload_patch, installed_app): + api = module.InstalledAppApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/", json={"is_pinned": True}), + payload_patch({"is_pinned": True}), + patch.object(module.db, "session"), + ): + result = method(installed_app) + + assert installed_app.is_pinned is True + assert result["result"] == "success" + + def test_patch_no_change(self, app, payload_patch, installed_app): + api = module.InstalledAppApi() + method = unwrap(api.patch) + + with app.test_request_context("/", json={}), payload_patch({}), patch.object(module.db, "session"): + result = method(installed_app) + + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py new file mode 100644 index 0000000000..c3a6522e6d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -0,0 +1,552 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import InternalServerError, NotFound + +import controllers.console.explore.message as module +from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.explore.error import ( + AppSuggestedQuestionsAfterAnswerDisabledError, + NotChatAppError, + NotCompletionAppError, +) +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError 
+from services.errors.conversation import ConversationNotExistsError +from services.errors.message import ( + FirstMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) + + +def unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def make_message(): + msg = MagicMock() + msg.id = "m1" + msg.conversation_id = "11111111-1111-1111-1111-111111111111" + msg.parent_message_id = None + msg.inputs = {} + msg.query = "hello" + msg.re_sign_file_url_answer = "" + msg.user_feedback = MagicMock(rating=None) + msg.status = "success" + msg.error = None + return msg + + +class TestMessageListApi: + def test_get_success(self, app): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + pagination = MagicMock( + limit=20, + has_more=False, + data=[make_message(), make_message()], + ) + + with ( + app.test_request_context( + "/", + query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "pagination_by_first_id", + return_value=pagination, + ), + ): + result = method(installed_app) + + assert result["limit"] == 20 + assert result["has_more"] is False + assert len(result["data"]) == 2 + + def test_get_not_chat_app(self): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotChatAppError): + method(installed_app) + + def test_conversation_not_exists(self, app): + api = module.MessageListApi() + method = unwrap(api.get) + + 
installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + app.test_request_context( + "/", + query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "pagination_by_first_id", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app) + + def test_first_message_not_exists(self, app): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + app.test_request_context( + "/", + query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "pagination_by_first_id", + side_effect=FirstMessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app) + + +class TestMessageFeedbackApi: + def test_post_success(self, app): + api = module.MessageFeedbackApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock() + + with ( + app.test_request_context("/", json={"rating": "like"}), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "create_feedback", + ), + ): + result = method(installed_app, "mid") + + assert result["result"] == "success" + + def test_message_not_exists(self, app): + api = module.MessageFeedbackApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock() + + with ( + app.test_request_context("/", json={}), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "create_feedback", + side_effect=MessageNotExistsError(), + ), + ): + with 
pytest.raises(NotFound): + method(installed_app, "mid") + + +class TestMessageMoreLikeThisApi: + def test_get_success(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + return_value={"ok": True}, + ), + patch.object( + module.helper, + "compact_generate_response", + return_value=("ok", 200), + ), + ): + resp = method(installed_app, "mid") + + assert resp == ("ok", 200) + + def test_not_completion_app(self): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotCompletionAppError): + method(installed_app, "mid") + + def test_more_like_this_disabled(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=module.MoreLikeThisDisabledError(), + ), + ): + with pytest.raises(AppMoreLikeThisDisabledError): + method(installed_app, "mid") + + def test_message_not_exists_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": 
"blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + def test_provider_not_init_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(installed_app, "mid") + + def test_quota_exceeded_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(installed_app, "mid") + + def test_model_not_support_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + 
side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(installed_app, "mid") + + def test_invoke_error_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(installed_app, "mid") + + def test_unexpected_error_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=Exception("unexpected"), + ), + ): + with pytest.raises(InternalServerError): + method(installed_app, "mid") + + +class TestMessageSuggestedQuestionApi: + def test_get_success(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ), + ): + result = method(installed_app, "mid") + + assert result["data"] == ["q1", "q2"] + + def test_not_chat_app(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + 
installed_app.app = MagicMock(mode="completion") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotChatAppError): + method(installed_app, "mid") + + def test_disabled(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=SuggestedQuestionsAfterAnswerDisabledError(), + ), + ): + with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError): + method(installed_app, "mid") + + def test_message_not_exists_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + def test_conversation_not_exists_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + def test_provider_not_init_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, 
"current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(installed_app, "mid") + + def test_quota_exceeded_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(installed_app, "mid") + + def test_model_not_support_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(installed_app, "mid") + + def test_invoke_error_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(installed_app, "mid") + + def test_unexpected_error_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = 
MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=Exception("unexpected"), + ), + ): + with pytest.raises(InternalServerError): + method(installed_app, "mid") diff --git a/api/tests/unit_tests/controllers/console/explore/test_parameter.py b/api/tests/unit_tests/controllers/console/explore/test_parameter.py new file mode 100644 index 0000000000..7aaecbff14 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_parameter.py @@ -0,0 +1,140 @@ +from unittest.mock import MagicMock, patch + +import pytest + +import controllers.console.explore.parameter as module +from controllers.console.app.error import AppUnavailableError +from models.model import AppMode + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAppParameterApi: + def test_get_app_none(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + installed_app = MagicMock(app=None) + + with pytest.raises(AppUnavailableError): + method(installed_app) + + def test_get_advanced_chat_workflow(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + workflow = MagicMock() + workflow.features_dict = {"f": "v"} + workflow.user_input_form.return_value = [{"name": "x"}] + + app = MagicMock( + mode=AppMode.ADVANCED_CHAT, + workflow=workflow, + ) + + installed_app = MagicMock(app=app) + + with ( + patch.object( + module, + "get_parameters_from_feature_dict", + return_value={"any": "thing"}, + ), + patch.object( + module.fields.Parameters, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: {"ok": True}), + ), + ): + result = method(installed_app) + + assert result == {"ok": True} + + def test_get_advanced_chat_workflow_missing(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + 
app = MagicMock( + mode=AppMode.ADVANCED_CHAT, + workflow=None, + ) + + installed_app = MagicMock(app=app) + + with pytest.raises(AppUnavailableError): + method(installed_app) + + def test_get_non_workflow_app(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + app_model_config = MagicMock() + app_model_config.to_dict.return_value = {"user_input_form": [{"name": "y"}]} + + app = MagicMock( + mode=AppMode.CHAT, + app_model_config=app_model_config, + ) + + installed_app = MagicMock(app=app) + + with ( + patch.object( + module, + "get_parameters_from_feature_dict", + return_value={"whatever": 123}, + ), + patch.object( + module.fields.Parameters, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: {"ok": True}), + ), + ): + result = method(installed_app) + + assert result == {"ok": True} + + def test_get_non_workflow_missing_config(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + app = MagicMock( + mode=AppMode.CHAT, + app_model_config=None, + ) + + installed_app = MagicMock(app=app) + + with pytest.raises(AppUnavailableError): + method(installed_app) + + +class TestExploreAppMetaApi: + def test_get_meta_success(self): + api = module.ExploreAppMetaApi() + method = unwrap(api.get) + + app = MagicMock() + installed_app = MagicMock(app=app) + + with patch.object( + module.AppService, + "get_app_meta", + return_value={"meta": "ok"}, + ): + result = method(installed_app) + + assert result == {"meta": "ok"} + + def test_get_meta_app_missing(self): + api = module.ExploreAppMetaApi() + method = unwrap(api.get) + + installed_app = MagicMock(app=None) + + with pytest.raises(ValueError): + method(installed_app) diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py new file mode 100644 index 0000000000..02c7507ea7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -0,0 +1,92 
@@ +from unittest.mock import MagicMock, patch + +import controllers.console.explore.recommended_app as module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestRecommendedAppListApi: + def test_get_with_language_param(self, app): + api = module.RecommendedAppListApi() + method = unwrap(api.get) + + result_data = {"recommended_apps": [], "categories": []} + + with ( + app.test_request_context("/", query_string={"language": "en-US"}), + patch.object(module, "current_user", MagicMock(interface_language="fr-FR")), + patch.object( + module.RecommendedAppService, + "get_recommended_apps_and_categories", + return_value=result_data, + ) as service_mock, + ): + result = method(api) + + service_mock.assert_called_once_with("en-US") + assert result == result_data + + def test_get_fallback_to_user_language(self, app): + api = module.RecommendedAppListApi() + method = unwrap(api.get) + + result_data = {"recommended_apps": [], "categories": []} + + with ( + app.test_request_context("/", query_string={"language": "invalid"}), + patch.object(module, "current_user", MagicMock(interface_language="fr-FR")), + patch.object( + module.RecommendedAppService, + "get_recommended_apps_and_categories", + return_value=result_data, + ) as service_mock, + ): + result = method(api) + + service_mock.assert_called_once_with("fr-FR") + assert result == result_data + + def test_get_fallback_to_default_language(self, app): + api = module.RecommendedAppListApi() + method = unwrap(api.get) + + result_data = {"recommended_apps": [], "categories": []} + + with ( + app.test_request_context("/"), + patch.object(module, "current_user", MagicMock(interface_language=None)), + patch.object( + module.RecommendedAppService, + "get_recommended_apps_and_categories", + return_value=result_data, + ) as service_mock, + ): + result = method(api) + + service_mock.assert_called_once_with(module.languages[0]) + assert result == result_data + + +class 
TestRecommendedAppApi: + def test_get_success(self, app): + api = module.RecommendedAppApi() + method = unwrap(api.get) + + result_data = {"id": "app1"} + + with ( + app.test_request_context("/"), + patch.object( + module.RecommendedAppService, + "get_recommend_app_detail", + return_value=result_data, + ) as service_mock, + ): + result = method(api, "11111111-1111-1111-1111-111111111111") + + service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") + assert result == result_data diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py new file mode 100644 index 0000000000..bb7cdd55c4 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -0,0 +1,154 @@ +from unittest.mock import MagicMock, PropertyMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import NotFound + +import controllers.console.explore.saved_message as module +from controllers.console.explore.error import NotCompletionAppError +from services.errors.message import MessageNotExistsError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def make_saved_message(): + msg = MagicMock() + msg.id = str(uuid4()) + msg.message_id = str(uuid4()) + msg.app_id = str(uuid4()) + msg.inputs = {} + msg.query = "hello" + msg.answer = "world" + msg.user_feedback = MagicMock(rating="like") + msg.created_at = None + return msg + + +@pytest.fixture +def payload_patch(): + def _patch(payload): + return patch.object( + type(module.console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return _patch + + +class TestSavedMessageListApi: + def test_get_success(self, app): + api = module.SavedMessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + pagination = MagicMock( + limit=20, + 
has_more=False, + data=[make_saved_message(), make_saved_message()], + ) + + with ( + app.test_request_context("/", query_string={}), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.SavedMessageService, + "pagination_by_last_id", + return_value=pagination, + ), + ): + result = method(installed_app) + + assert result["limit"] == 20 + assert result["has_more"] is False + assert len(result["data"]) == 2 + + def test_get_not_completion_app(self): + api = module.SavedMessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotCompletionAppError): + method(installed_app) + + def test_post_success(self, app, payload_patch): + api = module.SavedMessageListApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + payload = {"message_id": str(uuid4())} + + with ( + app.test_request_context("/", json=payload), + payload_patch(payload), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object(module.SavedMessageService, "save") as save_mock, + ): + result = method(installed_app) + + save_mock.assert_called_once() + assert result == {"result": "success"} + + def test_post_message_not_exists(self, app, payload_patch): + api = module.SavedMessageListApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + payload = {"message_id": str(uuid4())} + + with ( + app.test_request_context("/", json=payload), + payload_patch(payload), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.SavedMessageService, + "save", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + 
method(installed_app) + + +class TestSavedMessageApi: + def test_delete_success(self): + api = module.SavedMessageApi() + method = unwrap(api.delete) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object(module.SavedMessageService, "delete") as delete_mock, + ): + result, status = method(installed_app, str(uuid4())) + + delete_mock.assert_called_once() + assert status == 204 + assert result == {"result": "success"} + + def test_delete_not_completion_app(self): + api = module.SavedMessageApi() + method = unwrap(api.delete) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotCompletionAppError): + method(installed_app, str(uuid4())) diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py new file mode 100644 index 0000000000..d85114c8fb --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -0,0 +1,1101 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import controllers.console.explore.trial as module +from controllers.console.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.explore.error import ( + NotChatAppError, + NotCompletionAppError, + NotWorkflowAppError, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + 
QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from models import Account +from models.account import TenantStatus +from models.model import AppMode +from services.errors.conversation import ConversationNotExistsError +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def account(): + acc = MagicMock(spec=Account) + acc.id = "u1" + return acc + + +@pytest.fixture +def trial_app_chat(): + app = MagicMock() + app.id = "a-chat" + app.mode = AppMode.CHAT + return app + + +@pytest.fixture +def trial_app_completion(): + app = MagicMock() + app.id = "a-comp" + app.mode = AppMode.COMPLETION + return app + + +@pytest.fixture +def trial_app_workflow(): + app = MagicMock() + app.id = "a-workflow" + app.mode = AppMode.WORKFLOW + return app + + +@pytest.fixture +def valid_parameters(): + return { + "user_input_form": [], + "system_parameters": {}, + "suggested_questions": {}, + "suggested_questions_after_answer": {}, + "speech_to_text": {}, + "text_to_speech": {}, + "retriever_resource": {}, + "annotation_reply": {}, + "more_like_this": {}, + "sensitive_word_avoidance": {}, + "file_upload": {}, + } + + +class TestTrialAppWorkflowRunApi: + def test_not_workflow_app(self, app): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with app.test_request_context("/"): + with pytest.raises(NotWorkflowAppError): + method(MagicMock(mode=AppMode.CHAT)) + + def test_success(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(trial_app_workflow) + + assert result is 
not None + + def test_workflow_provider_not_init(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(trial_app_workflow) + + def test_workflow_quota_exceeded(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(trial_app_workflow) + + def test_workflow_model_not_support(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(trial_app_workflow) + + def test_workflow_invoke_error(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(trial_app_workflow) + + def test_workflow_rate_limit_error(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + 
app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("test"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(trial_app_workflow) + + def test_workflow_value_error(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "files": []}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ValueError("test error"), + ), + ): + with pytest.raises(ValueError): + method(trial_app_workflow) + + def test_workflow_generic_exception(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "files": []}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=RuntimeError("unexpected error"), + ), + ): + with pytest.raises(InternalServerError): + method(trial_app_workflow) + + +class TestTrialChatApi: + def test_not_chat_app(self, app): + api = module.TrialChatApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={"inputs": {}, "query": "hi"}): + with pytest.raises(NotChatAppError): + method(api, MagicMock(mode="completion")) + + def test_success(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_chat) + + assert result is not None + + def test_chat_conversation_not_exists(self, 
app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, trial_app_chat) + + def test_chat_conversation_completed(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.conversation.ConversationCompletedError(), + ), + ): + with pytest.raises(ConversationCompletedError): + method(api, trial_app_chat) + + def test_chat_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + method(api, trial_app_chat) + + def test_chat_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_chat_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + 
with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + def test_chat_model_not_support(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(api, trial_app_chat) + + def test_chat_invoke_error(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_chat) + + def test_chat_rate_limit_error(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("test"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, trial_app_chat) + + def test_chat_value_error(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + 
side_effect=ValueError("test error"), + ), + ): + with pytest.raises(ValueError): + method(api, trial_app_chat) + + def test_chat_generic_exception(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=RuntimeError("unexpected error"), + ), + ): + with pytest.raises(InternalServerError): + method(api, trial_app_chat) + + +class TestTrialCompletionApi: + def test_not_completion_app(self, app): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={"inputs": {}, "query": ""}): + with pytest.raises(NotCompletionAppError): + method(api, MagicMock(mode=AppMode.CHAT)) + + def test_success(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_completion) + + assert result is not None + + def test_completion_app_config_broken(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + method(api, trial_app_completion) + + def test_completion_provider_not_init(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = 
unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_completion) + + def test_completion_quota_exceeded(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_completion) + + def test_completion_model_not_support(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(api, trial_app_completion) + + def test_completion_invoke_error(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_completion) + + def test_completion_rate_limit_error(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + 
patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("test"), + ), + ): + with pytest.raises(InternalServerError): + method(api, trial_app_completion) + + def test_completion_value_error(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ValueError("test error"), + ), + ): + with pytest.raises(ValueError): + method(api, trial_app_completion) + + def test_completion_generic_exception(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=RuntimeError("unexpected error"), + ), + ): + with pytest.raises(InternalServerError): + method(api, trial_app_completion) + + +class TestTrialMessageSuggestedQuestionApi: + def test_not_chat_app(self, app): + api = module.TrialMessageSuggestedQuestionApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(NotChatAppError): + method(api, MagicMock(mode="completion"), str(uuid4())) + + def test_success(self, app, trial_app_chat, account): + api = module.TrialMessageSuggestedQuestionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(module, "current_user", account), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ), + ): + result = method(api, trial_app_chat, str(uuid4())) + + assert result == {"data": ["q1", "q2"]} + + def test_conversation_not_exists(self, app, trial_app_chat, account): + api = 
module.TrialMessageSuggestedQuestionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(module, "current_user", account), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, trial_app_chat, str(uuid4())) + + +class TestTrialAppParameterApi: + def test_app_unavailable(self): + api = module.TrialAppParameterApi() + method = unwrap(api.get) + + with pytest.raises(AppUnavailableError): + method(api, None) + + def test_success_non_workflow(self, valid_parameters): + api = module.TrialAppParameterApi() + method = unwrap(api.get) + + app_model = MagicMock( + mode=AppMode.CHAT, + app_model_config=MagicMock(to_dict=lambda: {"user_input_form": []}), + ) + + with ( + patch.object( + module, + "get_parameters_from_feature_dict", + return_value=valid_parameters, + ), + patch.object( + module.ParametersResponse, + "model_validate", + return_value=MagicMock(model_dump=lambda mode=None: {"ok": True}), + ), + ): + result = method(api, app_model) + + assert result == {"ok": True} + + +class TestTrialChatAudioApi: + def test_success(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_asr", return_value={"text": "hello"}), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_chat) + + assert result == {"text": "hello"} + + def test_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + 
file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(module.AppUnavailableError): + method(api, trial_app_chat) + + def test_no_audio_uploaded(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(module.NoAudioUploadedError): + method(api, trial_app_chat) + + def test_audio_too_large(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.AudioTooLargeServiceError("Too large"), + ), + ): + with pytest.raises(module.AudioTooLargeError): + method(api, trial_app_chat) + + def test_unsupported_audio_type(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": 
(file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.UnsupportedAudioTypeServiceError(), + ), + ): + with pytest.raises(module.UnsupportedAudioTypeError): + method(api, trial_app_chat) + + def test_provider_not_support_tts(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(module.ProviderNotSupportSpeechToTextError): + method(api, trial_app_chat) + + def test_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_asr", side_effect=ProviderTokenNotInitError("test")), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + 
patch.object(module.AudioService, "transcript_asr", side_effect=QuotaExceededError()), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + +class TestTrialChatTextApi: + def test_success(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", return_value={"audio": "base64_data"}), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_chat) + + assert result == {"audio": "base64_data"} + + def test_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(module.AppUnavailableError): + method(api, trial_app_chat) + + def test_provider_not_support(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(module.ProviderNotSupportSpeechToTextError): + method(api, trial_app_chat) + + def test_audio_too_large(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + 
patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.AudioTooLargeServiceError("Too large"), + ), + ): + with pytest.raises(module.AudioTooLargeError): + method(api, trial_app_chat) + + def test_no_audio_uploaded(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(module.NoAudioUploadedError): + method(api, trial_app_chat) + + def test_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=ProviderTokenNotInitError("test")), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=QuotaExceededError()), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + def test_model_not_support(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=ModelCurrentlyNotSupportError()), + ): + with 
pytest.raises(ProviderModelCurrentlyNotSupportError): + method(api, trial_app_chat) + + def test_invoke_error(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=InvokeError("test error")), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_chat) + + +class TestTrialAppWorkflowTaskStopApi: + def test_not_workflow_app(self, app, trial_app_chat): + api = module.TrialAppWorkflowTaskStopApi() + method = unwrap(api.post) + + with app.test_request_context("/"): + with pytest.raises(NotWorkflowAppError): + method(trial_app_chat, str(uuid4())) + + def test_success(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowTaskStopApi() + method = unwrap(api.post) + + task_id = str(uuid4()) + with ( + app.test_request_context("/"), + patch.object(module, "current_user", account), + patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, + patch.object(module.GraphEngineManager, "send_stop_command") as mock_send_cmd, + ): + result = method(trial_app_workflow, task_id) + + assert result == {"result": "success"} + mock_set_flag.assert_called_once_with(task_id) + mock_send_cmd.assert_called_once_with(task_id) + + +class TestTrialSitApi: + def test_no_site(self, app): + api = module.TrialSitApi() + method = unwrap(api.get) + app_model = MagicMock() + app_model.id = "a1" + + with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: + mock_query.return_value.where.return_value.first.return_value = None + with pytest.raises(Forbidden): + method(api, app_model) + + def test_archived_tenant(self, app): + api = module.TrialSitApi() + method = unwrap(api.get) + + site = MagicMock() + app_model = MagicMock() + app_model.id = "a1" + 
app_model.tenant = MagicMock() + app_model.tenant.status = TenantStatus.ARCHIVE + + with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: + mock_query.return_value.where.return_value.first.return_value = site + with pytest.raises(Forbidden): + method(api, app_model) + + def test_success(self, app): + api = module.TrialSitApi() + method = unwrap(api.get) + + site = MagicMock() + app_model = MagicMock() + app_model.id = "a1" + app_model.tenant = MagicMock() + app_model.tenant.status = TenantStatus.NORMAL + + with ( + app.test_request_context("/"), + patch.object(module.db.session, "query") as mock_query, + patch.object(module.SiteResponse, "model_validate") as mock_validate, + ): + mock_query.return_value.where.return_value.first.return_value = site + mock_validate_result = MagicMock() + mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"} + mock_validate.return_value = mock_validate_result + result = method(api, app_model) + + assert result == {"name": "test", "icon": "icon"} + + +class TestTrialChatAudioApiExceptionHandlers: + def test_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, 
"test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + def test_invoke_error(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_chat) + + +class TestTrialChatTextApiExceptionHandlers: + def test_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(module.AppUnavailableError): + method(api, trial_app_chat) + + def test_unsupported_audio_type(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.UnsupportedAudioTypeServiceError("test"), + ), + ): + with pytest.raises(module.UnsupportedAudioTypeError): + method(api, trial_app_chat) diff --git 
a/api/tests/unit_tests/controllers/console/explore/test_workflow.py b/api/tests/unit_tests/controllers/console/explore/test_workflow.py new file mode 100644 index 0000000000..445f887fd3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_workflow.py @@ -0,0 +1,151 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import InternalServerError + +from controllers.console.explore.error import NotWorkflowAppError +from controllers.console.explore.workflow import ( + InstalledAppWorkflowRunApi, + InstalledAppWorkflowTaskStopApi, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from models.model import AppMode +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +@pytest.fixture +def user(): + return MagicMock() + + +@pytest.fixture +def workflow_app(): + app = MagicMock() + app.mode = AppMode.WORKFLOW + return app + + +@pytest.fixture +def installed_workflow_app(workflow_app): + return MagicMock(app=workflow_app) + + +@pytest.fixture +def non_workflow_installed_app(): + app = MagicMock() + app.mode = AppMode.CHAT + return MagicMock(app=app) + + +@pytest.fixture +def payload(): + return {"inputs": {"a": 1}} + + +class TestInstalledAppWorkflowRunApi: + def test_not_workflow_app(self, app, non_workflow_installed_app): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(NotWorkflowAppError): + method(non_workflow_installed_app) + + def test_success(self, app, installed_workflow_app, user, payload): + api = InstalledAppWorkflowRunApi() + method = 
unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(user, None), + ), + patch( + "controllers.console.explore.workflow.AppGenerateService.generate", + return_value=MagicMock(), + ) as generate_mock, + ): + result = method(installed_workflow_app) + + generate_mock.assert_called_once() + assert result is not None + + def test_rate_limit_error(self, app, installed_workflow_app, user, payload): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(user, None), + ), + patch( + "controllers.console.explore.workflow.AppGenerateService.generate", + side_effect=InvokeRateLimitError("rate limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(installed_workflow_app) + + def test_unexpected_exception(self, app, installed_workflow_app, user, payload): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(user, None), + ), + patch( + "controllers.console.explore.workflow.AppGenerateService.generate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + method(installed_workflow_app) + + +class TestInstalledAppWorkflowTaskStopApi: + def test_not_workflow_app(self, non_workflow_installed_app): + api = InstalledAppWorkflowTaskStopApi() + method = unwrap(api.post) + + with pytest.raises(NotWorkflowAppError): + method(non_workflow_installed_app, "task-1") + + def test_success(self, installed_workflow_app): + api = InstalledAppWorkflowTaskStopApi() + method = unwrap(api.post) + + with ( + patch("controllers.console.explore.workflow.AppQueueManager.set_stop_flag_no_user_check") as stop_flag, 
+ patch("controllers.console.explore.workflow.GraphEngineManager.send_stop_command") as send_stop, + ): + result = method(installed_workflow_app, "task-1") + + stop_flag.assert_called_once_with("task-1") + send_stop.assert_called_once_with("task-1") + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/explore/test_wraps.py b/api/tests/unit_tests/controllers/console/explore/test_wraps.py new file mode 100644 index 0000000000..67e7a32591 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_wraps.py @@ -0,0 +1,244 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.console.explore.error import ( + AppAccessDeniedError, + TrialAppLimitExceeded, + TrialAppNotAllowed, +) +from controllers.console.explore.wraps import ( + InstalledAppResource, + TrialAppResource, + installed_app_required, + trial_app_required, + trial_feature_enable, + user_allowed_to_access_app, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_installed_app_required_not_found(): + @installed_app_required + def view(installed_app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch("controllers.console.explore.wraps.db.session.query") as q, + ): + q.return_value.where.return_value.first.return_value = None + + with pytest.raises(NotFound): + view("app-id") + + +def test_installed_app_required_app_deleted(): + installed_app = MagicMock(app=None) + + @installed_app_required + def view(installed_app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.delete"), + 
patch("controllers.console.explore.wraps.db.session.commit"), + ): + q.return_value.where.return_value.first.return_value = installed_app + + with pytest.raises(NotFound): + view("app-id") + + +def test_installed_app_required_success(): + installed_app = MagicMock(app=MagicMock()) + + @installed_app_required + def view(installed_app): + return installed_app + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch("controllers.console.explore.wraps.db.session.query") as q, + ): + q.return_value.where.return_value.first.return_value = installed_app + + result = view("app-id") + assert result == installed_app + + +def test_user_allowed_to_access_app_denied(): + installed_app = MagicMock(app_id="app-1") + + @user_allowed_to_access_app + def view(installed_app): + return "ok" + + feature = MagicMock() + feature.webapp_auth.enabled = True + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=feature, + ), + patch( + "controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", + return_value=False, + ), + ): + with pytest.raises(AppAccessDeniedError): + view(installed_app) + + +def test_user_allowed_to_access_app_success(): + installed_app = MagicMock(app_id="app-1") + + @user_allowed_to_access_app + def view(installed_app): + return "ok" + + feature = MagicMock() + feature.webapp_auth.enabled = True + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=feature, + ), + patch( + "controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", + 
return_value=True, + ), + ): + assert view(installed_app) == "ok" + + +def test_trial_app_required_not_allowed(): + @trial_app_required + def view(app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch("controllers.console.explore.wraps.db.session.query") as q, + ): + q.return_value.where.return_value.first.return_value = None + + with pytest.raises(TrialAppNotAllowed): + view("app-id") + + +def test_trial_app_required_limit_exceeded(): + trial_app = MagicMock(trial_limit=1, app=MagicMock()) + record = MagicMock(count=1) + + @trial_app_required + def view(app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch("controllers.console.explore.wraps.db.session.query") as q, + ): + q.return_value.where.return_value.first.side_effect = [ + trial_app, + record, + ] + + with pytest.raises(TrialAppLimitExceeded): + view("app-id") + + +def test_trial_app_required_success(): + trial_app = MagicMock(trial_limit=2, app=MagicMock()) + record = MagicMock(count=1) + + @trial_app_required + def view(app): + return app + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch("controllers.console.explore.wraps.db.session.query") as q, + ): + q.return_value.where.return_value.first.side_effect = [ + trial_app, + record, + ] + + result = view("app-id") + assert result == trial_app.app + + +def test_trial_feature_enable_disabled(): + @trial_feature_enable + def view(): + return "ok" + + features = MagicMock(enable_trial_app=False) + + with patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=features, + ): + with pytest.raises(Forbidden): + view() + + +def test_trial_feature_enable_enabled(): + @trial_feature_enable + def view(): + return 
"ok" + + features = MagicMock(enable_trial_app=True) + + with patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=features, + ): + assert view() == "ok" + + +def test_installed_app_resource_decorators(): + decorators = InstalledAppResource.method_decorators + assert len(decorators) == 4 + + +def test_trial_app_resource_decorators(): + decorators = TrialAppResource.method_decorators + assert len(decorators) == 3 diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py new file mode 100644 index 0000000000..769edc8d1c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -0,0 +1,278 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console.tag.tags import ( + TagBindingCreateApi, + TagBindingDeleteApi, + TagListApi, + TagUpdateDeleteApi, +) + + +def unwrap(func): + """ + Recursively unwrap decorated functions. 
+ """ + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_tag") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def admin_user(): + return MagicMock( + id="user-1", + has_edit_permission=True, + is_dataset_editor=True, + ) + + +@pytest.fixture +def readonly_user(): + return MagicMock( + id="user-2", + has_edit_permission=False, + is_dataset_editor=False, + ) + + +@pytest.fixture +def tag(): + tag = MagicMock() + tag.id = "tag-1" + tag.name = "test-tag" + tag.type = "knowledge" + return tag + + +@pytest.fixture +def payload_patch(): + def _patch(payload): + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return _patch + + +class TestTagListApi: + def test_get_success(self, app): + api = TagListApi() + method = unwrap(api.get) + + with app.test_request_context("/?type=knowledge"): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.tag.tags.TagService.get_tags", + return_value=[{"id": "1", "name": "tag"}], + ), + ): + result, status = method(api) + + assert status == 200 + assert isinstance(result, list) + + def test_post_success(self, app, admin_user, tag, payload_patch): + api = TagListApi() + method = unwrap(api.post) + + payload = {"name": "test-tag", "type": "knowledge"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch( + "controllers.console.tag.tags.TagService.save_tags", + return_value=tag, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["name"] == "test-tag" + + def test_post_forbidden(self, app, readonly_user, payload_patch): + api = TagListApi() + method = unwrap(api.post) + + payload = {"name": "x"} 
+ + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch(payload), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestTagUpdateDeleteApi: + def test_patch_success(self, app, admin_user, tag, payload_patch): + api = TagUpdateDeleteApi() + method = unwrap(api.patch) + + payload = {"name": "updated", "type": "knowledge"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch( + "controllers.console.tag.tags.TagService.update_tags", + return_value=tag, + ), + patch( + "controllers.console.tag.tags.TagService.get_tag_binding_count", + return_value=3, + ), + ): + result, status = method(api, "tag-1") + + assert status == 200 + assert result["binding_count"] == 3 + + def test_patch_forbidden(self, app, readonly_user, payload_patch): + api = TagUpdateDeleteApi() + method = unwrap(api.patch) + + payload = {"name": "x"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch(payload), + ): + with pytest.raises(Forbidden): + method(api, "tag-1") + + def test_delete_success(self, app, admin_user): + api = TagUpdateDeleteApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, "tenant-1"), + ), + patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock, + ): + result, status = method(api, "tag-1") + + delete_mock.assert_called_once_with("tag-1") + assert status == 204 + + +class TestTagBindingCreateApi: + def test_create_success(self, app, admin_user, payload_patch): + api = TagBindingCreateApi() + method = 
unwrap(api.post) + + payload = { + "tag_ids": ["tag-1"], + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, + ): + result, status = method(api) + + save_mock.assert_called_once() + assert status == 200 + assert result["result"] == "success" + + def test_create_forbidden(self, app, readonly_user, payload_patch): + api = TagBindingCreateApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={}): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch({}), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestTagBindingDeleteApi: + def test_remove_success(self, app, admin_user, payload_patch): + api = TagBindingDeleteApi() + method = unwrap(api.post) + + payload = { + "tag_id": "tag-1", + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, + ): + result, status = method(api) + + delete_mock.assert_called_once() + assert status == 200 + assert result["result"] == "success" + + def test_remove_forbidden(self, app, readonly_user, payload_patch): + api = TagBindingDeleteApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={}): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch({}), + ): + with pytest.raises(Forbidden): + method(api) diff --git 
a/api/tests/unit_tests/controllers/console/test_admin.py b/api/tests/unit_tests/controllers/console/test_admin.py index e0ddf6542e..16197fcd0c 100644 --- a/api/tests/unit_tests/controllers/console/test_admin.py +++ b/api/tests/unit_tests/controllers/console/test_admin.py @@ -1,13 +1,483 @@ """Final working unit tests for admin endpoints - tests business logic directly.""" import uuid -from unittest.mock import Mock, patch +from unittest.mock import Mock, PropertyMock, patch import pytest from werkzeug.exceptions import NotFound, Unauthorized -from controllers.console.admin import InsertExploreAppPayload -from models.model import App, RecommendedApp +from controllers.console.admin import ( + DeleteExploreBannerApi, + InsertExploreAppApi, + InsertExploreAppListApi, + InsertExploreAppPayload, + InsertExploreBannerApi, + InsertExploreBannerPayload, +) +from models.model import App, InstalledApp, RecommendedApp + + +@pytest.fixture(autouse=True) +def bypass_only_edition_cloud(mocker): + """ + Bypass only_edition_cloud decorator by setting EDITION to "CLOUD". + """ + mocker.patch( + "controllers.console.wraps.dify_config.EDITION", + new="CLOUD", + ) + + +@pytest.fixture +def mock_admin_auth(mocker): + """ + Provide valid admin authentication for controller tests. 
+ """ + mocker.patch( + "controllers.console.admin.dify_config.ADMIN_API_KEY", + "test-admin-key", + ) + mocker.patch( + "controllers.console.admin.extract_access_token", + return_value="test-admin-key", + ) + + +@pytest.fixture +def mock_console_payload(mocker): + payload = { + "app_id": str(uuid.uuid4()), + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + mocker.patch( + "flask_restx.namespace.Namespace.payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return payload + + +@pytest.fixture +def mock_banner_payload(mocker): + mocker.patch( + "flask_restx.namespace.Namespace.payload", + new_callable=PropertyMock, + return_value={ + "title": "Test Banner", + "description": "Banner description", + "img-src": "https://example.com/banner.png", + "link": "https://example.com", + "sort": 1, + "category": "homepage", + }, + ) + + +@pytest.fixture +def mock_session_factory(mocker): + mock_session = Mock() + mock_session.execute = Mock() + mock_session.add = Mock() + mock_session.commit = Mock() + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + +class TestDeleteExploreBannerApi: + def setup_method(self): + self.api = DeleteExploreBannerApi() + + def test_delete_banner_not_found(self, mocker, mock_admin_auth): + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: None), + ) + + with pytest.raises(NotFound, match="is not found"): + self.api.delete(uuid.uuid4()) + + def test_delete_banner_success(self, mocker, mock_admin_auth): + mock_banner = Mock() + + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: mock_banner), + ) + mocker.patch("controllers.console.admin.db.session.delete") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = 
self.api.delete(uuid.uuid4()) + + assert status == 204 + assert response["result"] == "success" + + +class TestInsertExploreBannerApi: + def setup_method(self): + self.api = InsertExploreBannerApi() + + def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload): + mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 201 + assert response["result"] == "success" + + def test_banner_payload_valid_language(self): + payload = { + "title": "Test Banner", + "description": "Banner description", + "img-src": "https://example.com/banner.png", + "link": "https://example.com", + "sort": 1, + "category": "homepage", + "language": "en-US", + } + + model = InsertExploreBannerPayload.model_validate(payload) + assert model.language == "en-US" + + def test_banner_payload_invalid_language(self): + payload = { + "title": "Test Banner", + "description": "Banner description", + "img-src": "https://example.com/banner.png", + "link": "https://example.com", + "sort": 1, + "category": "homepage", + "language": "invalid-lang", + } + + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreBannerPayload.model_validate(payload) + + +class TestInsertExploreAppApiDelete: + def setup_method(self): + self.api = InsertExploreAppApi() + + def test_delete_when_not_in_explore(self, mocker, mock_admin_auth): + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: s, + __exit__=Mock(return_value=False), + execute=lambda *_: Mock(scalar_one_or_none=lambda: None), + ), + ) + + response, status = self.api.delete(uuid.uuid4()) + + assert status == 204 + assert response["result"] == "success" + + def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth): + """Test deleting an app from explore that has a trial app.""" + app_id = uuid.uuid4() 
+ + mock_recommended = Mock(spec=RecommendedApp) + mock_recommended.app_id = "app-123" + + mock_app = Mock(spec=App) + mock_app.is_public = True + + mock_trial = Mock() + + # Mock session context manager and its execute + mock_session = Mock() + mock_session.execute = Mock() + mock_session.delete = Mock() + + # Set up side effects for execute calls + mock_session.execute.side_effect = [ + Mock(scalar_one_or_none=lambda: mock_recommended), + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalars=Mock(return_value=Mock(all=lambda: []))), + Mock(scalar_one_or_none=lambda: mock_trial), + ] + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + mocker.patch("controllers.console.admin.db.session.delete") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.delete(app_id) + + assert status == 204 + assert response["result"] == "success" + assert mock_app.is_public is False + + def test_delete_with_installed_apps(self, mocker, mock_admin_auth): + """Test deleting an app that has installed apps in other tenants.""" + app_id = uuid.uuid4() + + mock_recommended = Mock(spec=RecommendedApp) + mock_recommended.app_id = "app-123" + + mock_app = Mock(spec=App) + mock_app.is_public = True + + mock_installed_app = Mock(spec=InstalledApp) + + # Mock session + mock_session = Mock() + mock_session.execute = Mock() + mock_session.delete = Mock() + + mock_session.execute.side_effect = [ + Mock(scalar_one_or_none=lambda: mock_recommended), + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))), + Mock(scalar_one_or_none=lambda: None), + ] + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + 
mocker.patch("controllers.console.admin.db.session.delete") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.delete(app_id) + + assert status == 204 + assert mock_session.delete.called + + +class TestInsertExploreAppListApi: + def setup_method(self): + self.api = InsertExploreAppListApi() + + def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload): + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: None), + ) + + with pytest.raises(NotFound, match="is not found"): + self.api.post() + + def test_create_recommended_app( + self, + mocker, + mock_admin_auth, + mock_console_payload, + ): + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.tenant_id = "tenant" + mock_app.is_public = False + + # db.session.execute → fetch App + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: mock_app), + ) + + # session_factory.create_session → recommended_app lookup + mock_session = Mock() + mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None)) + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 201 + assert response["result"] == "success" + assert mock_app.is_public is True + + def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory): + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.is_public = False + + mock_recommended = Mock(spec=RecommendedApp) + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + 
Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: mock_recommended), + ], + ) + + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True + + def test_site_data_overrides_payload( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + site = Mock() + site.description = "Site Desc" + site.copyright = "Site Copyright" + site.privacy_policy = "Site Privacy" + site.custom_disclaimer = "Site Disclaimer" + + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = site + mock_app.tenant_id = "tenant" + mock_app.is_public = False + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: None), + Mock(scalar_one_or_none=lambda: None), + ], + ) + + commit_spy = mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True + commit_spy.assert_called_once() + + def test_create_trial_app_when_can_trial_enabled( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + mock_console_payload["can_trial"] = True + mock_console_payload["trial_limit"] = 5 + + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.tenant_id = "tenant" + mock_app.is_public = False + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: None), + Mock(scalar_one_or_none=lambda: None), + ], + ) + + add_spy = mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + self.api.post() + + assert 
any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list) + + def test_update_recommended_app_with_trial( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + """Test updating a recommended app when trial is enabled.""" + mock_console_payload["can_trial"] = True + mock_console_payload["trial_limit"] = 10 + + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.is_public = False + mock_app.tenant_id = "tenant-123" + + mock_recommended = Mock(spec=RecommendedApp) + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: mock_recommended), + Mock(scalar_one_or_none=lambda: None), + ], + ) + + add_spy = mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True + + def test_update_recommended_app_without_trial( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + """Test updating a recommended app without trial enabled.""" + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.is_public = False + + mock_recommended = Mock(spec=RecommendedApp) + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: mock_recommended), + ], + ) + + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True class TestInsertExploreAppPayload: diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py new file 
mode 100644 index 0000000000..018257f815 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -0,0 +1,138 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.apikey import ( + BaseApiKeyListResource, + BaseApiKeyResource, + _get_resource, +) + + +@pytest.fixture +def tenant_context_admin(): + with patch("controllers.console.apikey.current_account_with_tenant") as mock: + user = MagicMock() + user.is_admin_or_owner = True + mock.return_value = (user, "tenant-123") + yield mock + + +@pytest.fixture +def tenant_context_non_admin(): + with patch("controllers.console.apikey.current_account_with_tenant") as mock: + user = MagicMock() + user.is_admin_or_owner = False + mock.return_value = (user, "tenant-123") + yield mock + + +@pytest.fixture +def db_mock(): + with patch("controllers.console.apikey.db") as mock_db: + mock_db.session = MagicMock() + yield mock_db + + +@pytest.fixture(autouse=True) +def bypass_permissions(): + with patch( + "controllers.console.apikey.edit_permission_required", + lambda f: f, + ): + yield + + +class DummyApiKeyListResource(BaseApiKeyListResource): + resource_type = "app" + resource_model = MagicMock() + resource_id_field = "app_id" + token_prefix = "app-" + + +class DummyApiKeyResource(BaseApiKeyResource): + resource_type = "app" + resource_model = MagicMock() + resource_id_field = "app_id" + + +class TestGetResource: + def test_get_resource_success(self): + fake_resource = MagicMock() + + with ( + patch("controllers.console.apikey.select") as mock_select, + patch("controllers.console.apikey.Session") as mock_session, + patch("controllers.console.apikey.db") as mock_db, + ): + mock_db.engine = MagicMock() + mock_select.return_value.filter_by.return_value = MagicMock() + + session = mock_session.return_value.__enter__.return_value + session.execute.return_value.scalar_one_or_none.return_value = fake_resource + + result = 
_get_resource("rid", "tid", MagicMock) + assert result == fake_resource + + def test_get_resource_not_found(self): + with ( + patch("controllers.console.apikey.select") as mock_select, + patch("controllers.console.apikey.Session") as mock_session, + patch("controllers.console.apikey.db") as mock_db, + patch("controllers.console.apikey.flask_restx.abort") as abort, + ): + mock_db.engine = MagicMock() + mock_select.return_value.filter_by.return_value = MagicMock() + + session = mock_session.return_value.__enter__.return_value + session.execute.return_value.scalar_one_or_none.return_value = None + + _get_resource("rid", "tid", MagicMock) + + abort.assert_called_once() + + +class TestBaseApiKeyListResource: + def test_get_apikeys_success(self, tenant_context_admin, db_mock): + resource = DummyApiKeyListResource() + + with patch("controllers.console.apikey._get_resource"): + db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()] + + result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id") + assert "items" in result + + +class TestBaseApiKeyResource: + def test_delete_forbidden(self, tenant_context_non_admin, db_mock): + resource = DummyApiKeyResource() + + with patch("controllers.console.apikey._get_resource"): + with pytest.raises(Forbidden): + DummyApiKeyResource.delete(resource, "rid", "kid") + + def test_delete_key_not_found(self, tenant_context_admin, db_mock): + resource = DummyApiKeyResource() + db_mock.session.query.return_value.where.return_value.first.return_value = None + + with patch("controllers.console.apikey._get_resource"): + with pytest.raises(Exception) as exc_info: + DummyApiKeyResource.delete(resource, "rid", "kid") + + # flask_restx.abort raises HTTPException with message in data attribute + assert exc_info.value.data["message"] == "API key not found" + + def test_delete_success(self, tenant_context_admin, db_mock): + resource = DummyApiKeyResource() + 
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock() + + with ( + patch("controllers.console.apikey._get_resource"), + patch("controllers.console.apikey.ApiTokenCache.delete"), + ): + result, status = DummyApiKeyResource.delete(resource, "rid", "kid") + + assert status == 204 + assert result == {"result": "success"} + db_mock.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py deleted file mode 100644 index b9bc42fb25..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py +++ /dev/null @@ -1,46 +0,0 @@ -import builtins -from unittest.mock import patch - -import pytest -from flask import Flask -from flask.views import MethodView - -from extensions import ext_fastopenapi - -if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView # type: ignore[attr-defined] - - -@pytest.fixture -def app() -> Flask: - app = Flask(__name__) - app.config["TESTING"] = True - app.secret_key = "test-secret-key" - return app - - -def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch): - ext_fastopenapi.init_app(app) - monkeypatch.delenv("INIT_PASSWORD", raising=False) - - with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"): - client = app.test_client() - response = client.get("/console/api/init") - - assert response.status_code == 200 - assert response.get_json() == {"status": "finished"} - - -def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch): - ext_fastopenapi.init_app(app) - monkeypatch.setenv("INIT_PASSWORD", "test-init-password") - - with ( - patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"), - patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0), - ): - client = 
app.test_client() - response = client.post("/console/api/init", json={"password": "test-init-password"}) - - assert response.status_code == 201 - assert response.get_json() == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py deleted file mode 100644 index c0a984e216..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Tests for remote file upload API endpoints using Flask-RESTX.""" - -import contextlib -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import Mock, patch - -import httpx -import pytest -from flask import Flask, g - - -@pytest.fixture -def app() -> Flask: - """Create Flask app for testing.""" - app = Flask(__name__) - app.config["TESTING"] = True - app.config["SECRET_KEY"] = "test-secret-key" - return app - - -@pytest.fixture -def client(app): - """Create test client with console blueprint registered.""" - from controllers.console import bp - - app.register_blueprint(bp) - return app.test_client() - - -@pytest.fixture -def mock_account(): - """Create a mock account for testing.""" - from models import Account - - account = Mock(spec=Account) - account.id = "test-account-id" - account.current_tenant_id = "test-tenant-id" - return account - - -@pytest.fixture -def auth_ctx(app, mock_account): - """Context manager to set auth/tenant context in flask.g for a request.""" - - @contextlib.contextmanager - def _ctx(): - with app.test_request_context(): - g._login_user = mock_account - g._current_tenant = mock_account.current_tenant_id - yield - - return _ctx - - -class TestGetRemoteFileInfo: - """Test GET /console/api/remote-files/ endpoint.""" - - def test_get_remote_file_info_success(self, app, client, mock_account): - """Test successful retrieval of remote file info.""" - response = httpx.Response( - 200, - 
request=httpx.Request("HEAD", "http://example.com/file.txt"), - headers={"Content-Type": "text/plain", "Content-Length": "1024"}, - ) - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response), - patch("libs.login.check_csrf_token", return_value=None), - ): - with app.test_request_context(): - g._login_user = mock_account - g._current_tenant = mock_account.current_tenant_id - encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" - resp = client.get(f"/console/api/remote-files/{encoded_url}") - - assert resp.status_code == 200 - data = resp.get_json() - assert data["file_type"] == "text/plain" - assert data["file_length"] == 1024 - - def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account): - """Test fallback to GET when HEAD returns non-200 status.""" - head_response = httpx.Response( - 404, - request=httpx.Request("HEAD", "http://example.com/file.pdf"), - ) - get_response = httpx.Response( - 200, - request=httpx.Request("GET", "http://example.com/file.pdf"), - headers={"Content-Type": "application/pdf", "Content-Length": "2048"}, - ) - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), - patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response), - patch("libs.login.check_csrf_token", return_value=None), - ): - with app.test_request_context(): - g._login_user = mock_account - g._current_tenant = mock_account.current_tenant_id - encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf" - resp = client.get(f"/console/api/remote-files/{encoded_url}") - - assert resp.status_code == 200 - data = resp.get_json() - assert data["file_type"] == "application/pdf" - assert 
data["file_length"] == 2048 - - -class TestRemoteFileUpload: - """Test POST /console/api/remote-files/upload endpoint.""" - - @pytest.mark.parametrize( - ("head_status", "use_get"), - [ - (200, False), # HEAD succeeds - (405, True), # HEAD fails -> fallback GET - ], - ) - def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get): - url = "http://example.com/file.pdf" - head_resp = httpx.Response( - head_status, - request=httpx.Request("HEAD", url), - headers={"Content-Type": "application/pdf", "Content-Length": "1024"}, - ) - get_resp = httpx.Response( - 200, - request=httpx.Request("GET", url), - headers={"Content-Type": "application/pdf", "Content-Length": "1024"}, - content=b"file content", - ) - - file_info = SimpleNamespace( - extension="pdf", - size=1024, - filename="file.pdf", - mimetype="application/pdf", - ) - uploaded_file = SimpleNamespace( - id="uploaded-file-id", - name="file.pdf", - size=1024, - extension="pdf", - mime_type="application/pdf", - created_by="test-account-id", - created_at=datetime(2024, 1, 1, 12, 0, 0), - ) - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head, - patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get, - patch( - "controllers.console.remote_files.helpers.guess_file_info_from_response", - return_value=file_info, - ), - patch( - "controllers.console.remote_files.FileService.is_file_size_within_limit", - return_value=True, - ), - patch("controllers.console.remote_files.db", spec=["engine"]), - patch("controllers.console.remote_files.FileService") as mock_file_service, - patch( - "controllers.console.remote_files.file_helpers.get_signed_file_url", - return_value="http://example.com/signed-url", - ), - patch("libs.login.check_csrf_token", return_value=None), - ): - 
mock_file_service.return_value.upload_file.return_value = uploaded_file - - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": url}, - ) - - assert resp.status_code == 201 - p_head.assert_called_once() - # GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds - p_get.assert_called_once() - mock_file_service.return_value.upload_file.assert_called_once() - - data = resp.get_json() - assert data["id"] == "uploaded-file-id" - assert data["name"] == "file.pdf" - assert data["size"] == 1024 - assert data["extension"] == "pdf" - assert data["url"] == "http://example.com/signed-url" - assert data["mime_type"] == "application/pdf" - assert data["created_by"] == "test-account-id" - - @pytest.mark.parametrize( - ("size_ok", "raises", "expected_status", "expected_msg"), - [ - # When size check fails in controller, API returns 413 with message "File size exceeded..." - (False, None, 413, "file size exceeded"), - # When service raises unsupported type, controller maps to 415 with message "File type not allowed." 
- (True, "unsupported", 415, "file type not allowed"), - ], - ) - def test_upload_remote_file_errors( - self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg - ): - url = "http://example.com/x.pdf" - head_resp = httpx.Response( - 200, - request=httpx.Request("HEAD", url), - headers={"Content-Type": "application/pdf", "Content-Length": "9"}, - ) - file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf") - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp), - patch( - "controllers.console.remote_files.helpers.guess_file_info_from_response", - return_value=file_info, - ), - patch( - "controllers.console.remote_files.FileService.is_file_size_within_limit", - return_value=size_ok, - ), - patch("controllers.console.remote_files.db", spec=["engine"]), - patch("libs.login.check_csrf_token", return_value=None), - ): - if raises == "unsupported": - from services.errors.file import UnsupportedFileTypeError - - with patch("controllers.console.remote_files.FileService") as mock_file_service: - mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad") - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": url}, - ) - else: - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": url}, - ) - - assert resp.status_code == expected_status - data = resp.get_json() - msg = (data.get("error") or {}).get("message") or data.get("message", "") - assert expected_msg in msg.lower() - - def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx): - """Test upload when fetching of remote file fails.""" - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, 
"test-tenant-id"), - ), - patch( - "controllers.console.remote_files.ssrf_proxy.head", - side_effect=httpx.RequestError("Connection failed"), - ), - patch("libs.login.check_csrf_token", return_value=None), - ): - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": "http://unreachable.com/file.pdf"}, - ) - - assert resp.status_code == 400 - data = resp.get_json() - msg = (data.get("error") or {}).get("message") or data.get("message", "") - assert "failed to fetch" in msg.lower() diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py new file mode 100644 index 0000000000..d8debc1f2c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -0,0 +1,81 @@ +from werkzeug.exceptions import Unauthorized + + +def unwrap(func): + """ + Recursively unwrap decorated functions. + """ + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestFeatureApi: + def test_get_tenant_features_success(self, mocker): + from controllers.console.feature import FeatureApi + + mocker.patch( + "controllers.console.feature.current_account_with_tenant", + return_value=("account_id", "tenant_123"), + ) + + mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = { + "features": {"feature_a": True} + } + + api = FeatureApi() + + raw_get = unwrap(FeatureApi.get) + result = raw_get(api) + + assert result == {"features": {"feature_a": True}} + + +class TestSystemFeatureApi: + def test_get_system_features_authenticated(self, mocker): + """ + current_user.is_authenticated == True + """ + + from controllers.console.feature import SystemFeatureApi + + fake_user = mocker.Mock() + fake_user.is_authenticated = True + + mocker.patch( + "controllers.console.feature.current_user", + fake_user, + ) + + mocker.patch( + "controllers.console.feature.FeatureService.get_system_features" + 
).return_value.model_dump.return_value = {"features": {"sys_feature": True}} + + api = SystemFeatureApi() + result = api.get() + + assert result == {"features": {"sys_feature": True}} + + def test_get_system_features_unauthenticated(self, mocker): + """ + current_user.is_authenticated raises Unauthorized + """ + + from controllers.console.feature import SystemFeatureApi + + fake_user = mocker.Mock() + type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized()) + + mocker.patch( + "controllers.console.feature.current_user", + fake_user, + ) + + mocker.patch( + "controllers.console.feature.FeatureService.get_system_features" + ).return_value.model_dump.return_value = {"features": {"sys_feature": False}} + + api = SystemFeatureApi() + result = api.get() + + assert result == {"features": {"sys_feature": False}} diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py new file mode 100644 index 0000000000..5df9daa7f8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -0,0 +1,300 @@ +import io +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from constants import DOCUMENT_EXTENSIONS +from controllers.common.errors import ( + BlockedFileExtensionError, + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.files import ( + FileApi, + FilePreviewApi, + FileSupportTypeApi, +) + + +def unwrap(func): + """ + Recursively unwrap decorated functions. 
+ """ + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.testing = True + return app + + +@pytest.fixture(autouse=True) +def mock_decorators(): + """ + Make decorators no-ops so logic is directly testable + """ + with ( + patch("controllers.console.files.setup_required", new=lambda f: f), + patch("controllers.console.files.login_required", new=lambda f: f), + patch("controllers.console.files.account_initialization_required", new=lambda f: f), + patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f), + ): + yield + + +@pytest.fixture +def mock_current_user(): + user = MagicMock() + user.is_dataset_editor = True + return user + + +@pytest.fixture +def mock_account_context(mock_current_user): + with patch( + "controllers.console.files.current_account_with_tenant", + return_value=(mock_current_user, None), + ): + yield + + +@pytest.fixture +def mock_db(): + with patch("controllers.console.files.db") as db_mock: + db_mock.engine = MagicMock() + yield db_mock + + +@pytest.fixture +def mock_file_service(mock_db): + with patch("controllers.console.files.FileService") as fs: + instance = fs.return_value + yield instance + + +class TestFileApiGet: + def test_get_upload_config(self, app): + api = FileApi() + get_method = unwrap(api.get) + + with app.test_request_context(): + data, status = get_method(api) + + assert status == 200 + assert "file_size_limit" in data + assert "batch_count_limit" in data + + +class TestFileApiPost: + def test_no_file_uploaded(self, app, mock_account_context): + api = FileApi() + post_method = unwrap(api.post) + + with app.test_request_context(method="POST", data={}): + with pytest.raises(NoFileUploadedError): + post_method(api) + + def test_too_many_files(self, app, mock_account_context): + api = FileApi() + post_method = unwrap(api.post) + + with app.test_request_context(method="POST"): + from unittest.mock import 
MagicMock, patch + + with patch("controllers.console.files.request") as mock_request: + mock_request.files = MagicMock() + mock_request.files.__len__.return_value = 2 + mock_request.files.__contains__.return_value = True + mock_request.form = MagicMock() + mock_request.form.get.return_value = None + + with pytest.raises(TooManyFilesError): + post_method(api) + + def test_filename_missing(self, app, mock_account_context): + api = FileApi() + post_method = unwrap(api.post) + + data = { + "file": (io.BytesIO(b"abc"), ""), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(FilenameNotExistsError): + post_method(api) + + def test_dataset_upload_without_permission(self, app, mock_current_user): + mock_current_user.is_dataset_editor = False + + with patch( + "controllers.console.files.current_account_with_tenant", + return_value=(mock_current_user, None), + ): + api = FileApi() + post_method = unwrap(api.post) + + data = { + "file": (io.BytesIO(b"abc"), "test.txt"), + "source": "datasets", + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(Forbidden): + post_method(api) + + def test_successful_upload(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + mock_file = MagicMock() + mock_file.id = "file-id-123" + mock_file.filename = "test.txt" + mock_file.name = "test.txt" + mock_file.size = 1024 + mock_file.extension = "txt" + mock_file.mime_type = "text/plain" + mock_file.created_by = "user-123" + mock_file.created_at = 1234567890 + mock_file.preview_url = "http://example.com/preview/file-id-123" + mock_file.source_url = "http://example.com/source/file-id-123" + mock_file.original_url = None + mock_file.user_id = "user-123" + mock_file.tenant_id = "tenant-123" + mock_file.conversation_id = None + mock_file.file_key = "file-key-123" + + mock_file_service.upload_file.return_value = mock_file + + data = { + "file": (io.BytesIO(b"hello"), "test.txt"), + } 
+ + with app.test_request_context(method="POST", data=data): + response, status = post_method(api) + + assert status == 201 + assert response["id"] == "file-id-123" + assert response["name"] == "test.txt" + + def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service): + """Test that invalid source parameter gets normalized to None""" + api = FileApi() + post_method = unwrap(api.post) + + # Create a properly structured mock file object + mock_file = MagicMock() + mock_file.id = "file-id-456" + mock_file.filename = "test.txt" + mock_file.name = "test.txt" + mock_file.size = 512 + mock_file.extension = "txt" + mock_file.mime_type = "text/plain" + mock_file.created_by = "user-456" + mock_file.created_at = 1234567890 + mock_file.preview_url = None + mock_file.source_url = None + mock_file.original_url = None + mock_file.user_id = "user-456" + mock_file.tenant_id = "tenant-456" + mock_file.conversation_id = None + mock_file.file_key = "file-key-456" + + mock_file_service.upload_file.return_value = mock_file + + data = { + "file": (io.BytesIO(b"content"), "test.txt"), + "source": "invalid_source", # Should be normalized to None + } + + with app.test_request_context(method="POST", data=data): + response, status = post_method(api) + + assert status == 201 + assert response["id"] == "file-id-456" + # Verify that FileService was called with source=None + mock_file_service.upload_file.assert_called_once() + call_kwargs = mock_file_service.upload_file.call_args[1] + assert call_kwargs["source"] is None + + def test_file_too_large_error(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + from services.errors.file import FileTooLargeError as ServiceFileTooLargeError + + error = ServiceFileTooLargeError("File is too large") + mock_file_service.upload_file.side_effect = error + + data = { + "file": (io.BytesIO(b"x" * 1000000), "big.txt"), + } + + with app.test_request_context(method="POST", 
data=data): + with pytest.raises(FileTooLargeError): + post_method(api) + + def test_unsupported_file_type(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + error = ServiceUnsupportedFileTypeError() + mock_file_service.upload_file.side_effect = error + + data = { + "file": (io.BytesIO(b"x"), "bad.exe"), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(UnsupportedFileTypeError): + post_method(api) + + def test_blocked_extension(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError + + error = ServiceBlockedFileExtensionError("File extension is blocked") + mock_file_service.upload_file.side_effect = error + + data = { + "file": (io.BytesIO(b"x"), "blocked.txt"), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(BlockedFileExtensionError): + post_method(api) + + +class TestFilePreviewApi: + def test_get_preview(self, app, mock_file_service): + api = FilePreviewApi() + get_method = unwrap(api.get) + mock_file_service.get_file_preview.return_value = "preview text" + + with app.test_request_context(): + result = get_method(api, "1234") + + assert result == {"content": "preview text"} + + +class TestFileSupportTypeApi: + def test_get_supported_types(self, app): + api = FileSupportTypeApi() + get_method = unwrap(api.get) + + with app.test_request_context(): + result = get_method(api) + + assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)} diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py new file mode 100644 index 0000000000..232b6eee79 --- /dev/null +++ 
b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from flask import Response + +from controllers.console.human_input_form import ( + ConsoleHumanInputFormApi, + ConsoleWorkflowEventsApi, + DifyAPIRepositoryFactory, + WorkflowResponseConverter, + _jsonify_form_definition, +) +from controllers.web.error import NotFoundError +from models.enums import CreatorUserRole +from models.human_input import RecipientType +from models.model import AppMode + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_jsonify_form_definition() -> None: + expiration = datetime(2024, 1, 1, tzinfo=UTC) + definition = SimpleNamespace(model_dump=lambda: {"fields": []}) + form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration) + + response = _jsonify_form_definition(form) + + assert isinstance(response, Response) + payload = json.loads(response.get_data(as_text=True)) + assert payload["expiration_time"] == int(expiration.timestamp()) + + +def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1") + monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2")) + + with pytest.raises(NotFoundError): + ConsoleHumanInputFormApi._ensure_console_access(form) + + +def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + expiration = datetime(2024, 1, 1, tzinfo=UTC) + definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]}) + form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def 
get_form_definition_by_token_for_console(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1")) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/form/human_input/token", method="GET"): + response = handler(api, form_token="token") + + payload = json.loads(response.get_data(as_text=True)) + assert payload["fields"] == ["a"] + + +def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_definition_by_token_for_console(self, _token): + return None + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1")) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/form/human_input/token", method="GET"): + with pytest.raises(NotFoundError): + handler(api, form_token="token") + + +def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), 
"tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + with pytest.raises(NotFoundError): + handler(api, form_token="token") + + +def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + submit_mock = Mock() + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + def submit_form_by_token(self, **kwargs): + submit_mock(**kwargs) + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + response = handler(api, form_token="token") + + assert response.get_json() == {} + submit_mock.assert_called_once() + + +def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return None + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + 
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + with pytest.raises(NotFoundError): + handler(api, workflow_run_id="run-1") + + +def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + tenant_id="t1", + ) + + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return workflow_run + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + with pytest.raises(NotFoundError): + handler(api, workflow_run_id="run-1") + + +def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-2", + tenant_id="t1", + ) + + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return workflow_run + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = 
ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + with pytest.raises(NotFoundError): + handler(api, workflow_run_id="run-1") + + +def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-1", + tenant_id="t1", + app_id="app-1", + finished_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW) + + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return workflow_run + + response_obj = SimpleNamespace( + event=SimpleNamespace(value="finished"), + model_dump=lambda mode="json": {"status": "done"}, + ) + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form._retrieve_app_for_workflow_run", + lambda *_args, **_kwargs: app_model, + ) + monkeypatch.setattr( + WorkflowResponseConverter, + "workflow_run_result_to_finish_response", + lambda **_kwargs: response_obj, + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + response = handler(api, workflow_run_id="run-1") + + assert response.mimetype == "text/event-stream" + assert "data" in response.get_data(as_text=True) diff --git a/api/tests/unit_tests/controllers/console/test_init_validate.py b/api/tests/unit_tests/controllers/console/test_init_validate.py new file mode 100644 index 0000000000..3077304cbe --- /dev/null +++ 
b/api/tests/unit_tests/controllers/console/test_init_validate.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from controllers.console import init_validate +from controllers.console.error import AlreadySetupError, InitValidateFailedError + + +class _SessionStub: + def __init__(self, has_setup: bool): + self._has_setup = has_setup + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, *_args, **_kwargs): + return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None) + + +def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True) + result = init_validate.get_init_status() + assert result.status == "finished" + + +def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False) + result = init_validate.get_init_status() + assert result.status == "not_started" + + +def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1) + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="POST"): + with pytest.raises(AlreadySetupError): + init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw")) + + +def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setenv("INIT_PASSWORD", "expected") + app.secret_key = "test-secret" + + with 
app.test_request_context("/console/api/init", method="POST"): + with pytest.raises(InitValidateFailedError): + init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong")) + assert init_validate.session.get("is_init_validated") is False + + +def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setenv("INIT_PASSWORD", "expected") + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="POST"): + result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected")) + assert result.result == "success" + assert init_validate.session.get("is_init_validated") is True + + +def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD") + assert init_validate.get_init_validate_status() is True + + +def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setenv("INIT_PASSWORD", "expected") + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="GET"): + init_validate.session["is_init_validated"] = True + assert init_validate.get_init_validate_status() is True + + +def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setenv("INIT_PASSWORD", "expected") + monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True)) + monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object())) + app.secret_key = "test-secret" + + with 
app.test_request_context("/console/api/init", method="GET"): + init_validate.session.pop("is_init_validated", None) + assert init_validate.get_init_validate_status() is True + + +def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setenv("INIT_PASSWORD", "expected") + monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False)) + monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object())) + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="GET"): + init_validate.session.pop("is_init_validated", None) + assert init_validate.get_init_validate_status() is False diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py new file mode 100644 index 0000000000..1be402c8ab --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import urllib.parse +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import httpx +import pytest + +from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError +from controllers.console import remote_files as remote_files_module +from services.errors.file import FileTooLargeError as ServiceFileTooLargeError +from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class _FakeResponse: + def __init__( + self, + *, + status_code: int = 200, + headers: dict[str, str] | None = None, + method: str = "GET", + content: bytes = b"", + text: str = "", + error: Exception | None = None, + ) -> None: + self.status_code = status_code + 
self.headers = headers or {} + self.request = SimpleNamespace(method=method) + self.content = content + self.text = text + self._error = error + + def raise_for_status(self) -> None: + if self._error: + raise self._error + + +def _mock_upload_dependencies( + monkeypatch: pytest.MonkeyPatch, + *, + file_size_within_limit: bool = True, +): + file_info = SimpleNamespace( + filename="report.txt", + extension=".txt", + mimetype="text/plain", + size=3, + ) + monkeypatch.setattr( + remote_files_module.helpers, + "guess_file_info_from_response", + MagicMock(return_value=file_info), + ) + + file_service_cls = MagicMock() + file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit + monkeypatch.setattr(remote_files_module, "FileService", file_service_cls) + monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None)) + monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + remote_files_module.file_helpers, + "get_signed_file_url", + lambda upload_file_id: f"https://signed.example/{upload_file_id}", + ) + + return file_service_cls + + +def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.GetRemoteFileInfo() + handler = _unwrap(api.get) + decoded_url = "https://example.com/test.txt" + encoded_url = urllib.parse.quote(decoded_url, safe="") + + head_resp = _FakeResponse( + status_code=200, + headers={"Content-Type": "text/plain", "Content-Length": "128"}, + method="HEAD", + ) + head_mock = MagicMock(return_value=head_resp) + get_mock = MagicMock() + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + with app.test_request_context(method="GET"): + payload = handler(api, url=encoded_url) + + assert payload == {"file_type": "text/plain", "file_length": 128} + 
head_mock.assert_called_once_with(decoded_url) + get_mock.assert_not_called() + + +def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.GetRemoteFileInfo() + handler = _unwrap(api.get) + decoded_url = "https://example.com/test.txt" + encoded_url = urllib.parse.quote(decoded_url, safe="") + + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503))) + get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET")) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + with app.test_request_context(method="GET"): + payload = handler(api, url=encoded_url) + + assert payload == {"file_type": "application/octet-stream", "file_length": 0} + get_mock.assert_called_once_with(decoded_url, timeout=3) + + +def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/report.txt" + + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404))) + get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content") + get_mock = MagicMock(return_value=get_resp) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + file_service_cls = _mock_upload_dependencies(monkeypatch) + upload_file = SimpleNamespace( + id="file-1", + name="report.txt", + size=16, + extension=".txt", + mime_type="text/plain", + created_by="u1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + file_service_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context(method="POST", json={"url": url}): + payload, status = handler(api) + + assert status == 201 + assert payload["id"] == "file-1" + assert payload["url"] == 
"https://signed.example/file-1" + get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True) + file_service_cls.return_value.upload_file.assert_called_once_with( + filename="report.txt", + content=b"fallback-content", + mimetype="text/plain", + user=SimpleNamespace(id="u1"), + source_url=url, + ) + + +def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/photo.jpg" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")), + ) + extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content") + get_mock = MagicMock(return_value=extra_get_resp) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + file_service_cls = _mock_upload_dependencies(monkeypatch) + upload_file = SimpleNamespace( + id="file-2", + name="photo.jpg", + size=18, + extension=".jpg", + mime_type="image/jpeg", + created_by="u1", + created_at=datetime(2024, 1, 2, tzinfo=UTC), + ) + file_service_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context(method="POST", json={"url": url}): + payload, status = handler(api) + + assert status == 201 + assert payload["id"] == "file-2" + get_mock.assert_called_once_with(url) + assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content" + + +def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/fail.txt" + + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500))) + monkeypatch.setattr( + 
remote_files_module.ssrf_proxy, + "get", + MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")), + ) + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"): + handler(api) + + +def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/fail.txt" + + request = httpx.Request("HEAD", url) + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(side_effect=httpx.RequestError("network down", request=request)), + ) + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"): + handler(api) + + +def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/large.bin" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), + ) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + + _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(FileTooLargeError): + handler(api) + + +def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/large.bin" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), + ) + 
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded") + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(FileTooLargeError, match="size exceeded"): + handler(api) + + +def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/file.exe" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), + ) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError() + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(UnsupportedFileTypeError): + handler(api) diff --git a/api/tests/unit_tests/controllers/console/test_spec.py b/api/tests/unit_tests/controllers/console/test_spec.py new file mode 100644 index 0000000000..05a4befaa8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_spec.py @@ -0,0 +1,49 @@ +from unittest.mock import patch + +import controllers.console.spec as spec_module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestSpecSchemaDefinitionsApi: + def test_get_success(self): + api = spec_module.SpecSchemaDefinitionsApi() + method = unwrap(api.get) + + schema_definitions = [{"type": "string"}] + + with patch.object( + spec_module, + "SchemaManager", + ) as schema_manager_cls: + schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions + + resp, status = method(api) + + 
assert status == 200 + assert resp == schema_definitions + + def test_get_exception_returns_empty_list(self): + api = spec_module.SpecSchemaDefinitionsApi() + method = unwrap(api.get) + + with ( + patch.object( + spec_module, + "SchemaManager", + side_effect=Exception("boom"), + ), + patch.object( + spec_module.logger, + "exception", + ) as log_exception, + ): + resp, status = method(api) + + assert status == 200 + assert resp == [] + log_exception.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/test_version.py b/api/tests/unit_tests/controllers/console/test_version.py new file mode 100644 index 0000000000..8d8d324be1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_version.py @@ -0,0 +1,162 @@ +from unittest.mock import MagicMock, patch + +import controllers.console.version as version_module + + +class TestHasNewVersion: + def test_has_new_version_true(self): + result = version_module._has_new_version( + latest_version="1.2.0", + current_version="1.1.0", + ) + assert result is True + + def test_has_new_version_false(self): + result = version_module._has_new_version( + latest_version="1.0.0", + current_version="1.1.0", + ) + assert result is False + + def test_has_new_version_invalid_version(self): + with patch.object(version_module.logger, "warning") as log_warning: + result = version_module._has_new_version( + latest_version="invalid", + current_version="1.0.0", + ) + + assert result is False + log_warning.assert_called_once() + + +class TestCheckVersionUpdate: + def test_no_check_update_url(self): + query = version_module.VersionQuery(current_version="1.0.0") + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "", + ), + patch.object( + version_module.dify_config.project, + "version", + "1.0.0", + ), + patch.object( + version_module.dify_config, + "CAN_REPLACE_LOGO", + True, + ), + patch.object( + version_module.dify_config, + "MODEL_LB_ENABLED", + False, + ), + ): + result = 
version_module.check_version_update(query) + + assert result.version == "1.0.0" + assert result.can_auto_update is False + assert result.features.can_replace_logo is True + assert result.features.model_load_balancing_enabled is False + + def test_http_error_fallback(self): + query = version_module.VersionQuery(current_version="1.0.0") + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "http://example.com", + ), + patch.object( + version_module.httpx, + "get", + side_effect=Exception("boom"), + ), + patch.object( + version_module.logger, + "warning", + ) as log_warning, + ): + result = version_module.check_version_update(query) + + assert result.version == "1.0.0" + log_warning.assert_called_once() + + def test_new_version_available(self): + query = version_module.VersionQuery(current_version="1.0.0") + + response = MagicMock() + response.json.return_value = { + "version": "1.2.0", + "releaseDate": "2024-01-01", + "releaseNotes": "New features", + "canAutoUpdate": True, + } + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "http://example.com", + ), + patch.object( + version_module.httpx, + "get", + return_value=response, + ), + patch.object( + version_module.dify_config.project, + "version", + "1.0.0", + ), + patch.object( + version_module.dify_config, + "CAN_REPLACE_LOGO", + False, + ), + patch.object( + version_module.dify_config, + "MODEL_LB_ENABLED", + True, + ), + ): + result = version_module.check_version_update(query) + + assert result.version == "1.2.0" + assert result.release_date == "2024-01-01" + assert result.release_notes == "New features" + assert result.can_auto_update is True + + def test_no_new_version(self): + query = version_module.VersionQuery(current_version="1.2.0") + + response = MagicMock() + response.json.return_value = { + "version": "1.1.0", + } + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "http://example.com", + ), + patch.object( + 
version_module.httpx, + "get", + return_value=response, + ), + patch.object( + version_module.dify_config.project, + "version", + "1.2.0", + ), + ): + result = version_module.check_version_update(query) + + assert result.version == "1.2.0" + assert result.can_auto_update is False diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py new file mode 100644 index 0000000000..00d322fdea --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -0,0 +1,341 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from controllers.console import console_ns +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailCodeError, +) +from controllers.console.error import AccountInFreezeError +from controllers.console.workspace.account import ( + AccountAvatarApi, + AccountDeleteApi, + AccountDeleteVerifyApi, + AccountInitApi, + AccountIntegrateApi, + AccountInterfaceLanguageApi, + AccountInterfaceThemeApi, + AccountNameApi, + AccountPasswordApi, + AccountProfileApi, + AccountTimezoneApi, + ChangeEmailCheckApi, + ChangeEmailResetApi, + CheckEmailUnique, +) +from controllers.console.workspace.error import ( + AccountAlreadyInitedError, + CurrentPasswordIncorrectError, + InvalidAccountDeletionCodeError, +) +from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAccountInitApi: + def test_init_success(self, app): + api = AccountInitApi() + method = unwrap(api.post) + + account = MagicMock(status="inactive") + payload = { + "interface_language": "en-US", + "timezone": "UTC", + "invitation_code": "code123", + } + + with ( + app.test_request_context("/account/init", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", 
return_value=(account, "t1")), + patch("controllers.console.workspace.account.db.session.commit", return_value=None), + patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), + patch("controllers.console.workspace.account.db.session.query") as query_mock, + ): + query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused") + resp = method(api) + + assert resp["result"] == "success" + + def test_init_already_initialized(self, app): + api = AccountInitApi() + method = unwrap(api.post) + + account = MagicMock(status="active") + + with ( + app.test_request_context("/account/init"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + ): + with pytest.raises(AccountAlreadyInitedError): + method(api) + + +class TestAccountProfileApi: + def test_get_profile_success(self, app): + api = AccountProfileApi() + method = unwrap(api.get) + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/account/profile"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + ): + result = method(api) + + assert result["id"] == "u1" + + +class TestAccountUpdateApis: + @pytest.mark.parametrize( + ("api_cls", "payload"), + [ + (AccountNameApi, {"name": "test"}), + (AccountAvatarApi, {"avatar": "img.png"}), + (AccountInterfaceLanguageApi, {"interface_language": "en-US"}), + (AccountInterfaceThemeApi, {"interface_theme": "dark"}), + (AccountTimezoneApi, {"timezone": "UTC"}), + ], + ) + def test_update_success(self, app, api_cls, payload): + api = api_cls() + method = unwrap(api.post) + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + 
user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.account.AccountService.update_account", return_value=user), + ): + result = method(api) + + assert result["id"] == "u1" + + +class TestAccountPasswordApi: + def test_password_success(self, app): + api = AccountPasswordApi() + method = unwrap(api.post) + + payload = { + "password": "old", + "new_password": "new123", + "repeat_new_password": "new123", + } + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None), + ): + result = method(api) + + assert result["id"] == "u1" + + def test_password_wrong_current(self, app): + api = AccountPasswordApi() + method = unwrap(api.post) + + payload = { + "password": "bad", + "new_password": "new123", + "repeat_new_password": "new123", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.update_account_password", + side_effect=ServicePwdError(), + ), + ): + with pytest.raises(CurrentPasswordIncorrectError): + method(api) + + +class TestAccountIntegrateApi: + def test_get_integrates(self, app): + api = AccountIntegrateApi() + method = unwrap(api.get) + + account = 
MagicMock(id="acc1") + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock, + ): + scalars_mock.return_value.all.return_value = [] + result = method(api) + + assert "data" in result + assert len(result["data"]) == 2 + + +class TestAccountDeleteApi: + def test_delete_verify_success(self, app): + api = AccountDeleteVerifyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code", + return_value=("token", "1234"), + ), + patch( + "controllers.console.workspace.account.AccountService.send_account_deletion_verification_email", + return_value=None, + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_delete_invalid_code(self, app): + api = AccountDeleteApi() + method = unwrap(api.post) + + payload = {"token": "t", "code": "x"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.verify_account_deletion_code", + return_value=False, + ), + ): + with pytest.raises(InvalidAccountDeletionCodeError): + method(api) + + +class TestChangeEmailApis: + def test_check_email_code_invalid(self, app): + api = ChangeEmailCheckApi() + method = unwrap(api.post) + + payload = {"email": "a@test.com", "code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + 
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.account.AccountService.get_change_email_data", + return_value={"email": "a@test.com", "code": "y"}, + ), + ): + with pytest.raises(EmailCodeError): + method(api) + + def test_reset_email_already_used(self, app): + api = ChangeEmailResetApi() + method = unwrap(api.post) + + payload = {"new_email": "x@test.com", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False), + patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False), + ): + with pytest.raises(EmailAlreadyInUseError): + method(api) + + +class TestCheckEmailUniqueApi: + def test_email_unique_success(self, app): + api = CheckEmailUnique() + method = unwrap(api.post) + + payload = {"email": "ok@test.com"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False), + patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True), + ): + result = method(api) + + assert result["result"] == "success" + + def test_email_in_freeze(self, app): + api = CheckEmailUnique() + method = unwrap(api.post) + + payload = {"email": "x@test.com"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True), + ): + with pytest.raises(AccountInFreezeError): + 
method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py new file mode 100644 index 0000000000..b4e03f681d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py @@ -0,0 +1,139 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console.error import AccountNotFound +from controllers.console.workspace.agent_providers import ( + AgentProviderApi, + AgentProviderListApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAgentProviderListApi: + def test_get_success(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + providers = [{"name": "openai"}, {"name": "anthropic"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", + return_value=providers, + ), + ): + result = method(api) + + assert result == providers + + def test_get_empty_list(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", + return_value=[], + ), + ): + result = method(api) + + assert result == [] + + def test_get_account_not_found(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", 
+ side_effect=AccountNotFound(), + ), + ): + with pytest.raises(AccountNotFound): + method(api) + + +class TestAgentProviderApi: + def test_get_success(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + provider_name = "openai" + provider_data = {"name": "openai", "models": ["gpt-4"]} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", + return_value=provider_data, + ), + ): + result = method(api, provider_name) + + assert result == provider_data + + def test_get_provider_not_found(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + provider_name = "unknown" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", + return_value=None, + ), + ): + result = method(api, provider_name) + + assert result is None + + def test_get_account_not_found(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + side_effect=AccountNotFound(), + ), + ): + with pytest.raises(AccountNotFound): + method(api, "openai") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py new file mode 100644 index 0000000000..51f76af172 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -0,0 +1,305 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from 
controllers.console.workspace.endpoint import ( + EndpointCreateApi, + EndpointDeleteApi, + EndpointDisableApi, + EndpointEnableApi, + EndpointListApi, + EndpointListForSinglePluginApi, + EndpointUpdateApi, +) +from core.plugin.impl.exc import PluginPermissionDeniedError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user_and_tenant(): + return MagicMock(id="u1"), "t1" + + +@pytest.fixture +def patch_current_account(user_and_tenant): + with patch( + "controllers.console.workspace.endpoint.current_account_with_tenant", + return_value=user_and_tenant, + ): + yield + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointCreateApi: + def test_create_success(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {"a": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_create_permission_denied(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.endpoint.EndpointService.create_endpoint", + side_effect=PluginPermissionDeniedError("denied"), + ), + ): + with pytest.raises(ValueError): + method(api) + + def test_create_validation_error(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p1", + "name": "", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class 
TestEndpointListApi: + def test_list_success(self, app): + api = EndpointListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]), + ): + result = method(api) + + assert "endpoints" in result + assert len(result["endpoints"]) == 1 + + def test_list_invalid_query(self, app): + api = EndpointListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=0&page_size=10"), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointListForSinglePluginApi: + def test_list_for_plugin_success(self, app): + api = EndpointListForSinglePluginApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10&plugin_id=p1"), + patch( + "controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin", + return_value=[{"id": "e1"}], + ), + ): + result = method(api) + + assert "endpoints" in result + + def test_list_for_plugin_missing_param(self, app): + api = EndpointListForSinglePluginApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointDeleteApi: + def test_delete_success(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_delete_invalid_payload(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) + + def 
test_delete_service_failure(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointUpdateApi: + def test_update_success(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = { + "endpoint_id": "e1", + "name": "new-name", + "settings": {"x": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_update_validation_error(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1", "settings": {}} + + with ( + app.test_request_context("/", json=payload), + ): + with pytest.raises(ValueError): + method(api) + + def test_update_service_failure(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = { + "endpoint_id": "e1", + "name": "n", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointEnableApi: + def test_enable_success(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def 
test_enable_invalid_payload(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) + + def test_enable_service_failure(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointDisableApi: + def test_disable_success(self, app): + api = EndpointDisableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_disable_invalid_payload(self, app): + api = EndpointDisableApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py new file mode 100644 index 0000000000..b6708d1f6f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -0,0 +1,607 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import HTTPException + +import services +from controllers.console.auth.error import ( + CannotTransferOwnerToSelfError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, + MemberNotInTenantError, + NotOwnerError, + OwnerTransferLimitError, +) +from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded +from 
controllers.console.workspace.members import ( + DatasetOperatorMemberListApi, + MemberCancelInviteApi, + MemberInviteEmailApi, + MemberListApi, + MemberUpdateRoleApi, + OwnerTransfer, + OwnerTransferCheckApi, + SendOwnerTransferEmailApi, +) +from services.errors.account import AccountAlreadyInTenantError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestMemberListApi: + def test_get_success(self, app): + api = MemberListApi() + method = unwrap(api.get) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + member.id = "m1" + member.name = "Member" + member.email = "member@test.com" + member.avatar = "avatar.png" + member.role = "admin" + member.status = "active" + members = [member] + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members), + ): + result, status = method(api) + + assert status == 200 + assert len(result["accounts"]) == 1 + + def test_get_no_tenant(self, app): + api = MemberListApi() + method = unwrap(api.get) + + user = MagicMock(current_tenant=None) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(ValueError): + method(api) + + +class TestMemberInviteEmailApi: + def test_invite_success(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + "language": "en-US", + } + + with ( + app.test_request_context("/", json=payload), + 
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, status = method(api) + + assert status == 201 + assert result["result"] == "success" + + def test_invite_limit_exceeded(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = False + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + ): + with pytest.raises(WorkspaceMembersLimitExceeded): + method(api) + + def test_invite_already_member(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch( + "controllers.console.workspace.members.RegisterService.invite_new_member", + side_effect=AccountAlreadyInTenantError(), + ), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + 
): + result, status = method(api) + + assert result["invitation_results"][0]["status"] == "success" + + def test_invite_invalid_role(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + payload = { + "emails": ["a@test.com"], + "role": "owner", + } + + with app.test_request_context("/", json=payload): + result, status = method(api) + + assert status == 400 + assert result["code"] == "invalid-role" + + def test_invite_generic_exception(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch( + "controllers.console.workspace.members.RegisterService.invite_new_member", + side_effect=Exception("boom"), + ), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, _ = method(api) + + assert result["invitation_results"][0]["status"] == "failed" + + +class TestMemberCancelInviteApi: + def test_cancel_success(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + 
assert status == 200 + assert result["result"] == "success" + + def test_cancel_not_found(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + ): + q.return_value.where.return_value.first.return_value = None + + with pytest.raises(HTTPException): + method(api, "x") + + def test_cancel_cannot_operate_self(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.CannotOperateSelfError("x"), + ), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + assert status == 400 + + def test_cancel_no_permission(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.NoPermissionError("x"), + ), + ): + q.return_value.where.return_value.first.return_value = member + result, status = 
method(api, member.id) + + assert status == 403 + + def test_cancel_member_not_in_tenant(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.MemberNotInTenantError(), + ), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + assert status == 404 + + +class TestMemberUpdateRoleApi: + def test_update_success(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + payload = {"role": "normal"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get", return_value=member), + patch("controllers.console.workspace.members.TenantService.update_member_role"), + ): + result = method(api, "id") + + if isinstance(result, tuple): + result = result[0] + + assert result["result"] == "success" + + def test_update_invalid_role(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + payload = {"role": "invalid-role"} + + with app.test_request_context("/", json=payload): + result, status = method(api, "id") + + assert status == 400 + + def test_update_member_not_found(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + payload = {"role": "normal"} + + with ( + app.test_request_context("/", json=payload), + patch( + 
"controllers.console.workspace.members.current_account_with_tenant", + return_value=(MagicMock(current_tenant=MagicMock()), "t1"), + ), + patch("controllers.console.workspace.members.db.session.get", return_value=None), + ): + with pytest.raises(HTTPException): + method(api, "id") + + +class TestDatasetOperatorMemberListApi: + def test_get_success(self, app): + api = DatasetOperatorMemberListApi() + method = unwrap(api.get) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + member.id = "op1" + member.name = "Operator" + member.email = "operator@test.com" + member.avatar = "avatar.png" + member.role = "operator" + member.status = "active" + members = [member] + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members + ), + ): + result, status = method(api) + + assert status == 200 + assert len(result["accounts"]) == 1 + + def test_get_no_tenant(self, app): + api = DatasetOperatorMemberListApi() + method = unwrap(api.get) + + user = MagicMock(current_tenant=None) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(ValueError): + method(api) + + +class TestSendOwnerTransferEmailApi: + def test_send_success(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(name="ws") + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), + 
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token" + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_send_ip_limit(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True), + ): + with pytest.raises(EmailSendIpLimitError): + method(api) + + def test_send_not_owner(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/", json={}), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False), + ): + with pytest.raises(NotOwnerError): + method(api) + + +class TestOwnerTransferCheckApi: + def test_check_invalid_code(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + 
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "a@test.com", "code": "y"}, + ), + ): + with pytest.raises(EmailCodeError): + method(api) + + def test_rate_limited(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=True, + ), + ): + with pytest.raises(OwnerTransferLimitError): + method(api) + + def test_invalid_token(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), + ): + with pytest.raises(InvalidTokenError): + method(api) + + def test_invalid_email(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = 
MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "b@test.com", "code": "x"}, + ), + ): + with pytest.raises(InvalidEmailError): + method(api) + + +class TestOwnerTransferApi: + def test_transfer_self(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + ): + with pytest.raises(CannotTransferOwnerToSelfError): + method(api, "1") + + def test_invalid_token(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), + ): + with pytest.raises(InvalidTokenError): + method(api, "2") + + def test_member_not_in_tenant(self, app): + api = 
OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + member = MagicMock() + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "a@test.com"}, + ), + patch("controllers.console.workspace.members.db.session.get", return_value=member), + patch("controllers.console.workspace.members.TenantService.is_member", return_value=False), + ): + with pytest.raises(MemberNotInTenantError): + method(api, "2") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py new file mode 100644 index 0000000000..af0c2c5594 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -0,0 +1,388 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic_core import ValidationError +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace.model_providers import ( + ModelProviderCredentialApi, + ModelProviderCredentialSwitchApi, + ModelProviderIconApi, + ModelProviderListApi, + ModelProviderPaymentCheckoutUrlApi, + ModelProviderValidateApi, + PreferredProviderTypeUpdateApi, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError + +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" +INVALID_UUID = "123" + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestModelProviderListApi: + def test_get_success(self, app): + api = ModelProviderListApi() + method = unwrap(api.get) + + with ( + 
app.test_request_context("/?model_type=llm"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_provider_list", + return_value=[{"name": "openai"}], + ), + ): + result = method(api) + + assert "data" in result + + +class TestModelProviderCredentialApi: + def test_get_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context(f"/?credential_id={VALID_UUID}"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential", + return_value={"key": "value"}, + ), + ): + result = method(api, provider="openai") + + assert "credentials" in result + + def test_get_invalid_uuid(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context(f"/?credential_id={INVALID_UUID}"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + def test_post_create_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}, "name": "test"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", + return_value=None, + ), + ): + result, status = method(api, provider="openai") + + assert result["result"] == "success" + assert status == 201 + + def 
test_post_create_validation_error(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", + side_effect=CredentialsValidateFailedError("bad"), + ), + ): + with pytest.raises(ValueError): + method(api, provider="openai") + + def test_put_update_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.put) + + payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_put_invalid_uuid(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.put) + + payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + def test_delete_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.delete) + + payload = {"credential_id": VALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + 
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential", + return_value=None, + ), + ): + result, status = method(api, provider="openai") + + assert result["result"] == "success" + assert status == 204 + + +class TestModelProviderCredentialSwitchApi: + def test_switch_success(self, app): + api = ModelProviderCredentialSwitchApi() + method = unwrap(api.post) + + payload = {"credential_id": VALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_switch_invalid_uuid(self, app): + api = ModelProviderCredentialSwitchApi() + method = unwrap(api.post) + + payload = {"credential_id": INVALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + +class TestModelProviderValidateApi: + def test_validate_success(self, app): + api = ModelProviderValidateApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_validate_failure(self, app): + api = ModelProviderValidateApi() + method = 
unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", + side_effect=CredentialsValidateFailedError("bad"), + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "error" + + +class TestModelProviderIconApi: + def test_icon_success(self, app): + api = ModelProviderIconApi() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon", + return_value=(b"123", "image/png"), + ), + ): + response = api.get("t1", "openai", "logo", "en") + + assert response.mimetype == "image/png" + + def test_icon_not_found(self, app): + api = ModelProviderIconApi() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon", + return_value=(None, None), + ), + ): + with pytest.raises(ValueError): + api.get("t1", "openai", "logo", "en") + + +class TestPreferredProviderTypeUpdateApi: + def test_update_success(self, app): + api = PreferredProviderTypeUpdateApi() + method = unwrap(api.post) + + payload = {"preferred_provider_type": "custom"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_invalid_enum(self, app): + api = PreferredProviderTypeUpdateApi() + method = unwrap(api.post) + + payload = 
{"preferred_provider_type": "invalid"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + +class TestModelProviderPaymentCheckoutUrlApi: + def test_checkout_success(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + user = MagicMock(id="u1", email="x@test.com") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(user, "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", + return_value=None, + ), + patch( + "controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link", + return_value={"url": "x"}, + ), + ): + result = method(api, provider="anthropic") + + assert "url" in result + + def test_invalid_provider(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(ValueError): + method(api, provider="openai") + + def test_permission_denied(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + user = MagicMock(id="u1", email="x@test.com") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(user, "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", + side_effect=Forbidden(), + ), + ): + with pytest.raises(Forbidden): + method(api, provider="anthropic") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py new file mode 100644 index 0000000000..43b8e1ac2e --- 
/dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -0,0 +1,447 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.workspace.models import ( + DefaultModelApi, + ModelProviderAvailableModelApi, + ModelProviderModelApi, + ModelProviderModelCredentialApi, + ModelProviderModelCredentialSwitchApi, + ModelProviderModelDisableApi, + ModelProviderModelEnableApi, + ModelProviderModelParameterRuleApi, + ModelProviderModelValidateApi, +) +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDefaultModelApi: + def test_get_success(self, app: Flask): + api = DefaultModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context( + "/", + query_string={"model_type": ModelType.LLM.value}, + ), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"} + + result = method(api) + + assert "data" in result + + def test_post_success(self, app: Flask): + api = DefaultModelApi() + method = unwrap(api.post) + + payload = { + "model_settings": [ + { + "model_type": ModelType.LLM.value, + "provider": "openai", + "model": "gpt-4", + } + ] + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api) + + assert result["result"] == "success" + + def test_get_returns_empty_when_no_default(self, app): + 
api = DefaultModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_default_model_of_model_type.return_value = None + + result = method(api) + + assert "data" in result + + +class TestModelProviderModelApi: + def test_get_models_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_models_by_provider.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + def test_post_models_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "load_balancing": { + "configs": [{"weight": 1}], + "enabled": True, + }, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + patch("controllers.console.workspace.models.ModelLoadBalancingService"), + ): + result, status = method(api, "openai") + + assert status == 200 + + def test_delete_model_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.delete) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + 
"controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 204 + + def test_get_models_returns_empty(self, app): + api = ModelProviderModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_models_by_provider.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + +class TestModelProviderModelCredentialApi: + def test_get_credentials_success(self, app: Flask): + api = ModelProviderModelCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context( + "/", + query_string={ + "model": "gpt-4", + "model_type": ModelType.LLM.value, + }, + ), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as provider_service, + patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service, + ): + provider_service.return_value.get_model_credential.return_value = { + "credentials": {}, + "current_credential_id": None, + "current_credential_name": None, + } + provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] + lb_service.return_value.get_load_balancing_configs.return_value = (False, []) + + result = method(api, "openai") + + assert "credentials" in result + + def test_create_credential_success(self, app: Flask): + api = ModelProviderModelCredentialApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credentials": {"key": "val"}, + 
} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 201 + + def test_get_empty_credentials(self, app): + api = ModelProviderModelCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb, + ): + service.return_value.get_model_credential.return_value = None + service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] + lb.return_value.get_load_balancing_configs.return_value = (False, []) + + result = method(api, "openai") + + assert result["credentials"] == {} + + def test_delete_success(self, app): + api = ModelProviderModelCredentialApi() + method = unwrap(api.delete) + + payload = { + "model": "gpt", + "model_type": ModelType.LLM.value, + "credential_id": "123e4567-e89b-12d3-a456-426614174000", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 204 + + +class TestModelProviderModelCredentialSwitchApi: + def test_switch_success(self, app: Flask): + api = ModelProviderModelCredentialSwitchApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credential_id": "abc", + } + + with ( + 
app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + +class TestModelEnableDisableApis: + def test_enable_model(self, app: Flask): + api = ModelProviderModelEnableApi() + method = unwrap(api.patch) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + def test_disable_model(self, app: Flask): + api = ModelProviderModelDisableApi() + method = unwrap(api.patch) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + +class TestModelProviderModelValidateApi: + def test_validate_success(self, app: Flask): + api = ModelProviderModelValidateApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credentials": {"key": "val"}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == 
"success" + + @pytest.mark.parametrize("model_name", ["gpt-4", "gpt"]) + def test_validate_failure(self, app: Flask, model_name: str): + api = ModelProviderModelValidateApi() + method = unwrap(api.post) + + payload = { + "model": model_name, + "model_type": ModelType.LLM.value, + "credentials": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid") + + result = method(api, "openai") + + assert result["result"] == "error" + + +class TestParameterAndAvailableModels: + def test_parameter_rules(self, app: Flask): + api = ModelProviderModelParameterRuleApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt-4"}), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_model_parameter_rules.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + def test_available_models(self, app: Flask): + api = ModelProviderAvailableModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_models_by_model_type.return_value = [] + + result = method(api, ModelType.LLM.value) + + assert "data" in result + + def test_empty_rules(self, app): + api = ModelProviderModelParameterRuleApi() + method = unwrap(api.get) + + with ( + 
app.test_request_context("/", query_string={"model": "gpt"}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_model_parameter_rules.return_value = [] + + result = method(api, "openai") + + assert result["data"] == [] + + def test_no_models(self, app): + api = ModelProviderAvailableModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_models_by_model_type.return_value = [] + + result = method(api, ModelType.LLM.value) + + assert result["data"] == [] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py new file mode 100644 index 0000000000..f6db55db5b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -0,0 +1,1019 @@ +import io +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace.plugin import ( + PluginAssetApi, + PluginAutoUpgradeExcludePluginApi, + PluginChangePermissionApi, + PluginChangePreferencesApi, + PluginDebuggingKeyApi, + PluginDeleteAllInstallTaskItemsApi, + PluginDeleteInstallTaskApi, + PluginDeleteInstallTaskItemApi, + PluginFetchDynamicSelectOptionsApi, + PluginFetchDynamicSelectOptionsWithCredentialsApi, + PluginFetchInstallTaskApi, + PluginFetchInstallTasksApi, + PluginFetchManifestApi, + PluginFetchMarketplacePkgApi, + PluginFetchPermissionApi, + PluginFetchPreferencesApi, + PluginIconApi, + PluginInstallFromGithubApi, + PluginInstallFromMarketplaceApi, + 
PluginInstallFromPkgApi, + PluginListApi, + PluginListInstallationsFromIdsApi, + PluginListLatestVersionsApi, + PluginReadmeApi, + PluginUninstallApi, + PluginUpgradeFromGithubApi, + PluginUpgradeFromMarketplaceApi, + PluginUploadFromBundleApi, + PluginUploadFromGithubApi, + PluginUploadFromPkgApi, +) +from core.plugin.impl.exc import PluginDaemonClientSideError +from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user(): + u = MagicMock() + u.id = "u1" + u.is_admin_or_owner = True + return u + + +@pytest.fixture +def tenant(): + return "t1" + + +class TestPluginListLatestVersionsApi: + def test_success(self, app): + api = PluginListLatestVersionsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.plugin.PluginService.list_latest_versions", return_value={"p1": "1.0"} + ), + ): + result = method(api) + + assert "versions" in result + + def test_daemon_error(self, app): + api = PluginListLatestVersionsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.plugin.PluginService.list_latest_versions", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginDebuggingKeyApi: + def test_debugging_key_success(self, app): + api = PluginDebuggingKeyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.get_debugging_key", return_value="k"), + ): + result = method(api) + + assert result["key"] == "k" + + def 
test_debugging_key_error(self, app): + api = PluginDebuggingKeyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.get_debugging_key", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginListApi: + def test_plugin_list(self, app): + api = PluginListApi() + method = unwrap(api.get) + + mock_list = MagicMock(list=[{"id": 1}], total=1) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.list_with_total", return_value=mock_list), + ): + result = method(api) + + assert result["total"] == 1 + + +class TestPluginIconApi: + def test_plugin_icon(self, app): + api = PluginIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?tenant_id=t1&filename=a.png"), + patch("controllers.console.workspace.plugin.PluginService.get_asset", return_value=(b"x", "image/png")), + ): + response = method(api) + + assert response.mimetype == "image/png" + + +class TestPluginAssetApi: + def test_plugin_asset(self, app): + api = PluginAssetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p&file_name=a.bin"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.extract_asset", return_value=b"x"), + ): + response = method(api) + + assert response.mimetype == "application/octet-stream" + + +class TestPluginUploadFromPkgApi: + def test_upload_pkg_success(self, app): + api = PluginUploadFromPkgApi() + method = unwrap(api.post) + + data = { + "pkg": (io.BytesIO(b"x"), "test.pkg"), + 
} + + with ( + app.test_request_context("/", data=data, content_type="multipart/form-data"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.upload_pkg", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_upload_pkg_too_large(self, app): + api = PluginUploadFromPkgApi() + method = unwrap(api.post) + + data = { + "pkg": (io.BytesIO(b"x"), "test.pkg"), + } + + with ( + app.test_request_context("/", data=data, content_type="multipart/form-data"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromPkgApi: + def test_install_from_pkg(self, app): + api = PluginInstallFromPkgApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_local_pkg", return_value={"ok": True} + ), + ): + result = method(api) + + assert result["ok"] is True + + +class TestPluginUninstallApi: + def test_uninstall(self, app): + api = PluginUninstallApi() + method = unwrap(api.post) + + payload = {"plugin_installation_id": "x"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.uninstall", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + +class TestPluginChangePermissionApi: + def test_change_permission_forbidden(self, app): + api = 
PluginChangePermissionApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=False) + + payload = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(Forbidden): + method(api) + + def test_change_permission_success(self, app): + api = PluginChangePermissionApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + +class TestPluginFetchPermissionApi: + def test_fetch_permission_default(self, app): + api = PluginFetchPermissionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=None), + ): + result = method(api) + + assert result["install_permission"] is not None + + +class TestPluginFetchDynamicSelectOptionsApi: + def test_fetch_dynamic_options(self, app, user): + api = PluginFetchDynamicSelectOptionsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_id=p&provider=x&action=y&parameter=z&provider_type=tool"), + patch("controllers.console.workspace.plugin.current_account_with_tenant",
return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options", + return_value=[1, 2], + ), + ): + result = method(api) + + assert result["options"] == [1, 2] + + +class TestPluginReadmeApi: + def test_fetch_readme(self, app): + api = PluginReadmeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_readme", return_value="readme"), + ): + result = method(api) + + assert result["readme"] == "readme" + + +class TestPluginListInstallationsFromIdsApi: + def test_success(self, app): + api = PluginListInstallationsFromIdsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1", "p2"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", + return_value=[{"id": "p1"}], + ), + ): + result = method(api) + + assert "plugins" in result + + def test_daemon_error(self, app): + api = PluginListInstallationsFromIdsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUploadFromGithubApi: + def test_success(self, app): + api = PluginUploadFromGithubApi() + method = unwrap(api.post) + + payload = {"repo": "r", "version": "v", "package": "p"} + + with ( + app.test_request_context("/", 
json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", return_value={"ok": True} + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUploadFromGithubApi() + method = unwrap(api.post) + + payload = {"repo": "r", "version": "v", "package": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUploadFromBundleApi: + def test_success(self, app): + api = PluginUploadFromBundleApi() + method = unwrap(api.post) + + file = FileStorage( + stream=io.BytesIO(b"x"), + filename="test.bundle", + content_type="application/octet-stream", + ) + + with ( + app.test_request_context( + "/", + data={"bundle": file}, + content_type="multipart/form-data", + ), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.upload_bundle", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_too_large(self, app): + api = PluginUploadFromBundleApi() + method = unwrap(api.post) + + file = FileStorage( + stream=io.BytesIO(b"x"), + filename="test.bundle", + content_type="application/octet-stream", + ) + + with ( + app.test_request_context( + "/", + data={"bundle": file}, + content_type="multipart/form-data", + ), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), + ): + 
with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromGithubApi: + def test_success(self, app): + api = PluginInstallFromGithubApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.install_from_github", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginInstallFromGithubApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromMarketplaceApi: + def test_success(self, app): + api = PluginInstallFromMarketplaceApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginInstallFromMarketplaceApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + 
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchMarketplacePkgApi: + def test_success(self, app): + api = PluginFetchMarketplacePkgApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", return_value={"m": 1}), + ): + result = method(api) + + assert "manifest" in result + + def test_daemon_error(self, app): + api = PluginFetchMarketplacePkgApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchManifestApi: + def test_success(self, app): + api = PluginFetchManifestApi() + method = unwrap(api.get) + + manifest = MagicMock() + manifest.model_dump.return_value = {"x": 1} + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", return_value=manifest), + ): + result = method(api) + + assert "manifest" in result + + def test_daemon_error(self, app): + api = PluginFetchManifestApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + 
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchInstallTasksApi: + def test_success(self, app): + api = PluginFetchInstallTasksApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_install_tasks", return_value=[{"id": 1}]), + ): + result = method(api) + + assert "tasks" in result + + def test_daemon_error(self, app): + api = PluginFetchInstallTasksApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_install_tasks", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchInstallTaskApi: + def test_success(self, app): + api = PluginFetchInstallTaskApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_install_task", return_value={"id": "x"}), + ): + result = method(api, "x") + + assert "task" in result + + def test_daemon_error(self, app): + api = PluginFetchInstallTaskApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + 
"controllers.console.workspace.plugin.PluginService.fetch_install_task", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "t") + + +class TestPluginDeleteInstallTaskApi: + def test_success(self, app): + api = PluginDeleteInstallTaskApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.delete_install_task", return_value=True), + ): + result = method(api, "x") + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteInstallTaskApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_install_task", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "t") + + +class TestPluginDeleteAllInstallTaskItemsApi: + def test_success(self, app): + api = PluginDeleteAllInstallTaskItemsApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", return_value=True + ), + ): + result = method(api) + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteAllInstallTaskItemsApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", + side_effect=PluginDaemonClientSideError("error"), + ), + ): 
+ with pytest.raises(ValueError): + method(api) + + +class TestPluginDeleteInstallTaskItemApi: + def test_success(self, app): + api = PluginDeleteInstallTaskItemApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.delete_install_task_item", return_value=True), + ): + result = method(api, "task1", "item1") + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteInstallTaskItemApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_install_task_item", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "task1", "item1") + + +class TestPluginUpgradeFromMarketplaceApi: + def test_success(self, app): + api = PluginUpgradeFromMarketplaceApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUpgradeFromMarketplaceApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + 
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUpgradeFromGithubApi: + def test_success(self, app): + api = PluginUpgradeFromGithubApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUpgradeFromGithubApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: + def test_success(self, app): + api = PluginFetchDynamicSelectOptionsWithCredentialsApi() + method = unwrap(api.post) + + user = MagicMock(id="u1", is_admin_or_owner=True) + + payload = { + "plugin_id": "p", + "provider": "x", + "action": "y", + "parameter": "z", + "credential_id": "c", + "credentials": {"k": "v"}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", 
return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", + return_value=[1], + ), + ): + result = method(api) + + assert result["options"] == [1] + + def test_daemon_error(self, app): + api = PluginFetchDynamicSelectOptionsWithCredentialsApi() + method = unwrap(api.post) + + user = MagicMock(id="u1", is_admin_or_owner=True) + + payload = { + "plugin_id": "p", + "provider": "x", + "action": "y", + "parameter": "z", + "credential_id": "c", + "credentials": {"k": "v"}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginChangePreferencesApi: + def test_success(self, app): + api = PluginChangePreferencesApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "permission": { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + }, + "auto_upgrade": { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + "upgrade_time_of_day": 0, + "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + }, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.change_strategy", return_value=True), + ): + result = method(api) + + 
assert result["success"] is True + + def test_permission_fail(self, app): + api = PluginChangePreferencesApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "permission": { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + }, + "auto_upgrade": { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + "upgrade_time_of_day": 0, + "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + }, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +class TestPluginFetchPreferencesApi: + def test_success(self, app): + api = PluginFetchPreferencesApi() + method = unwrap(api.get) + + permission = MagicMock( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + + auto_upgrade = MagicMock( + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=1, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=[], + include_plugins=[], + ) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=permission + ), + patch( + "controllers.console.workspace.plugin.PluginAutoUpgradeService.get_strategy", return_value=auto_upgrade + ), + ): + result = method(api) + + assert "permission" in result + assert "auto_upgrade" in 
result + + +class TestPluginAutoUpgradeExcludePluginApi: + def test_success(self, app): + api = PluginAutoUpgradeExcludePluginApi() + method = unwrap(api.post) + + payload = {"plugin_id": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_fail(self, app): + api = PluginAutoUpgradeExcludePluginApi() + method = unwrap(api.post) + + payload = {"plugin_id": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=False), + ): + result = method(api) + + assert result["success"] is False diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py index b15676d9b7..16ea1bf509 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_restx import Api +from werkzeug.exceptions import Forbidden -from controllers.console.workspace.tool_providers import ToolProviderMCPApi +from controllers.console.workspace.tool_providers import ( + ToolApiListApi, + ToolApiProviderAddApi, + ToolApiProviderDeleteApi, + ToolApiProviderGetApi, + ToolApiProviderGetRemoteSchemaApi, + ToolApiProviderListToolsApi, + ToolApiProviderUpdateApi, + ToolBuiltinListApi, + ToolBuiltinProviderAddApi, + ToolBuiltinProviderCredentialsSchemaApi, + ToolBuiltinProviderDeleteApi, + 
ToolBuiltinProviderGetCredentialInfoApi, + ToolBuiltinProviderGetCredentialsApi, + ToolBuiltinProviderGetOauthClientSchemaApi, + ToolBuiltinProviderIconApi, + ToolBuiltinProviderInfoApi, + ToolBuiltinProviderListToolsApi, + ToolBuiltinProviderSetDefaultApi, + ToolBuiltinProviderUpdateApi, + ToolLabelsApi, + ToolOAuthCallback, + ToolOAuthCustomClient, + ToolPluginOAuthApi, + ToolProviderListApi, + ToolProviderMCPApi, + ToolWorkflowListApi, + ToolWorkflowProviderCreateApi, + ToolWorkflowProviderDeleteApi, + ToolWorkflowProviderGetApi, + ToolWorkflowProviderUpdateApi, + is_valid_url, +) from core.db.session_factory import configure_session_factory from extensions.ext_database import db from services.tools.mcp_tools_manage_service import ReconnectResult -# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file. -# They are intentionally no-ops because the test already patches the required -# behaviors explicitly via @patch and context managers below. +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + @pytest.fixture def _mock_cache(): return @@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_ # 若 transform 后包含 tools 字段,确保非空 assert isinstance(body.get("tools"), list) assert body["tools"] + + +class TestUtils: + def test_is_valid_url(self): + assert is_valid_url("https://example.com") + assert is_valid_url("http://example.com") + assert not is_valid_url("") + assert not is_valid_url("ftp://example.com") + assert not is_valid_url("not-a-url") + assert not is_valid_url(None) + + +class TestToolProviderListApi: + def test_get_success(self, app): + api = ToolProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "t1"), + ), + patch( + 
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers", + return_value=["p1"], + ), + ): + assert method(api) == ["p1"] + + +class TestBuiltinProviderApis: + def test_list_tools(self, app): + api = ToolBuiltinProviderListToolsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools", + return_value=[{"a": 1}], + ), + ): + assert method(api, "provider") == [{"a": 1}] + + def test_info(self, app): + api = ToolBuiltinProviderInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info", + return_value={"x": 1}, + ), + ): + assert method(api, "provider") == {"x": 1} + + def test_delete(self, app): + api = ToolBuiltinProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credential_id": "cid"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider", + return_value={"result": "success"}, + ), + ): + assert method(api, "provider")["result"] == "success" + + def test_add_invalid_type(self, app): + api = ToolBuiltinProviderAddApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}, "type": "invalid"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + ): + with pytest.raises(ValueError): + 
method(api, "provider") + + def test_add_success(self, app): + api = ToolBuiltinProviderAddApi() + method = unwrap(api.post) + + payload = {"credentials": {}, "type": "oauth2", "name": "n"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider", + return_value={"id": 1}, + ), + ): + assert method(api, "provider")["id"] == 1 + + def test_update(self, app): + api = ToolBuiltinProviderUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "c1", "credentials": {}, "name": "n"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_credentials(self, app): + api = ToolBuiltinProviderGetCredentialsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials", + return_value={"k": "v"}, + ), + ): + assert method(api, "provider") == {"k": "v"} + + def test_icon(self, app): + api = ToolBuiltinProviderIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon", + return_value=(b"x", "image/png"), + ), + ): + response = method(api, "provider") + assert response.mimetype == "image/png" + + def 
test_credentials_schema(self, app): + api = ToolBuiltinProviderCredentialsSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema", + return_value={"schema": {}}, + ), + ): + assert method(api, "provider", "oauth2") == {"schema": {}} + + def test_set_default_credential(self, app): + api = ToolBuiltinProviderSetDefaultApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"id": "c1"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_credential_info(self, app): + api = ToolBuiltinProviderGetCredentialInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info", + return_value={"info": "x"}, + ), + ): + assert method(api, "provider") == {"info": "x"} + + def test_get_oauth_client_schema(self, app): + api = ToolBuiltinProviderGetOauthClientSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema", + return_value={"schema": {}}, + ), + ): + assert 
method(api, "provider") == {"schema": {}} + + +class TestApiProviderApis: + def test_add(self, app): + api = ToolApiProviderAddApi() + method = unwrap(api.post) + + payload = { + "credentials": {}, + "schema_type": "openapi", + "schema": "{}", + "provider": "p", + "icon": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider", + return_value={"id": 1}, + ), + ): + assert method(api)["id"] == 1 + + def test_remote_schema(self, app): + api = ToolApiProviderGetRemoteSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?url=http://x.com"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema", + return_value={"schema": "x"}, + ), + ): + assert method(api)["schema"] == "x" + + def test_list_tools(self, app): + api = ToolApiProviderListToolsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?provider=p"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools", + return_value=[{"tool": 1}], + ), + ): + assert method(api) == [{"tool": 1}] + + def test_update(self, app): + api = ToolApiProviderUpdateApi() + method = unwrap(api.post) + + payload = { + "credentials": {}, + "schema_type": "openapi", + "schema": "{}", + "provider": "p", + "original_provider": "o", + "icon": {}, + "privacy_policy": "", + "custom_disclaimer": "", + } + + with ( + app.test_request_context("/", json=payload), + patch( + 
"controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider", + return_value={"ok": True}, + ), + ): + assert method(api)["ok"] + + def test_delete(self, app): + api = ToolApiProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"provider": "p"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider", + return_value={"result": "success"}, + ), + ): + assert method(api)["result"] == "success" + + def test_get(self, app): + api = ToolApiProviderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?provider=p"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider", + return_value={"x": 1}, + ), + ): + assert method(api) == {"x": 1} + + +class TestWorkflowApis: + def test_create(self, app): + api = ToolWorkflowProviderCreateApi() + method = unwrap(api.post) + + payload = { + "workflow_app_id": "123e4567-e89b-12d3-a456-426614174000", + "name": "n", + "label": "l", + "description": "d", + "icon": {}, + "parameters": [], + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool", + return_value={"id": 1}, + ), + ): + assert method(api)["id"] == 1 + + def test_update_invalid(self, app): + api = ToolWorkflowProviderUpdateApi() + method = unwrap(api.post) 
+ + payload = { + "workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000", + "name": "Tool", + "label": "Tool Label", + "description": "A tool", + "icon": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool", + return_value={"ok": True}, + ), + ): + result = method(api) + assert result["ok"] + + def test_delete(self, app): + api = ToolWorkflowProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool", + return_value={"ok": True}, + ), + ): + assert method(api)["ok"] + + def test_get_error(self, app): + api = ToolWorkflowProviderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestLists: + def test_builtin_list(self, app): + api = ToolBuiltinListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + def test_api_list(self, app): + api = ToolApiListApi() + method = unwrap(api.get) + + m 
= MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + def test_workflow_list(self, app): + api = ToolWorkflowListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + +class TestLabels: + def test_labels(self, app): + api = ToolLabelsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels", + return_value=["l1"], + ), + ): + assert method(api) == ["l1"] + + +class TestOAuth: + def test_oauth_no_client(self, app): + api = ToolPluginOAuthApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "provider") + + def test_oauth_callback_no_cookie(self, app): + api = ToolOAuthCallback() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "provider") + + +class TestOAuthCustomClient: + def test_save_custom_client(self, app): + api = ToolOAuthCustomClient() + method = 
unwrap(api.post) + + with ( + app.test_request_context("/", json={"client_params": {"a": 1}}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params", + return_value={"client_id": "x"}, + ), + ): + assert method(api, "provider") == {"client_id": "x"} + + def test_delete_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py new file mode 100644 index 0000000000..4776bc7af0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py @@ -0,0 +1,558 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden + +from controllers.console.workspace.trigger_providers import ( + TriggerOAuthAuthorizeApi, + TriggerOAuthCallbackApi, + TriggerOAuthClientManageApi, + 
TriggerProviderIconApi, + TriggerProviderInfoApi, + TriggerProviderListApi, + TriggerSubscriptionBuilderBuildApi, + TriggerSubscriptionBuilderCreateApi, + TriggerSubscriptionBuilderGetApi, + TriggerSubscriptionBuilderLogsApi, + TriggerSubscriptionBuilderUpdateApi, + TriggerSubscriptionBuilderVerifyApi, + TriggerSubscriptionDeleteApi, + TriggerSubscriptionListApi, + TriggerSubscriptionUpdateApi, + TriggerSubscriptionVerifyApi, +) +from controllers.web.error import NotFoundError +from core.plugin.entities.plugin_daemon import CredentialType +from models.account import Account + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def mock_user(): + user = MagicMock(spec=Account) + user.id = "u1" + user.current_tenant_id = "t1" + return user + + +class TestTriggerProviderApis: + def test_icon_success(self, app): + api = TriggerProviderIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon", + return_value="icon", + ), + ): + assert method(api, "github") == "icon" + + def test_list_providers(self, app): + api = TriggerProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers", + return_value=[], + ), + ): + assert method(api) == [] + + def test_provider_info(self, app): + api = TriggerProviderInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider", + return_value={"id": "p1"}, + ), 
+ ): + assert method(api, "github") == {"id": "p1"} + + +class TestTriggerSubscriptionListApi: + def test_list_success(self, app): + api = TriggerSubscriptionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", + return_value=[], + ), + ): + assert method(api, "github") == [] + + def test_list_invalid_provider(self, app): + api = TriggerSubscriptionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", + side_effect=ValueError("bad"), + ), + ): + result, status = method(api, "bad") + assert status == 404 + + +class TestTriggerSubscriptionBuilderApis: + def test_create_builder(self, app): + api = TriggerSubscriptionBuilderCreateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + return_value={"id": "b1"}, + ), + ): + result = method(api, "github") + assert "subscription_builder" in result + + def test_get_builder(self, app): + api = TriggerSubscriptionBuilderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + return_value={"id": "b1"}, + ), + ): + assert method(api, "github", "b1") == {"id": "b1"} + + def test_verify_builder(self, app): + api = 
TriggerSubscriptionBuilderVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {"a": 1}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", + return_value={"ok": True}, + ), + ): + assert method(api, "github", "b1") == {"ok": True} + + def test_verify_builder_error(self, app): + api = TriggerSubscriptionBuilderVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", + side_effect=Exception("err"), + ), + ): + with pytest.raises(ValueError): + method(api, "github", "b1") + + def test_update_builder(self, app): + api = TriggerSubscriptionBuilderUpdateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "n"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + return_value={"id": "b1"}, + ), + ): + assert method(api, "github", "b1") == {"id": "b1"} + + def test_logs(self, app): + api = TriggerSubscriptionBuilderLogsApi() + method = unwrap(api.get) + + log = MagicMock() + log.model_dump.return_value = {"a": 1} + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs", + return_value=[log], + ), + ): + assert "logs" in method(api, "github", "b1") + + def test_build(self, app): + api = TriggerSubscriptionBuilderBuildApi() + method = 
unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder", + return_value=None, + ), + ): + assert method(api, "github", "b1") == 200 + + +class TestTriggerSubscriptionCrud: + def test_update_rename_only(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + sub = MagicMock() + sub.provider_id = "github" + sub.credential_type = CredentialType.UNAUTHORIZED + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=sub, + ), + patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"), + ): + assert method(api, "s1") == 200 + + def test_update_not_found(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=None, + ), + ): + with pytest.raises(NotFoundError): + method(api, "x") + + def test_update_rebuild(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + sub = MagicMock() + sub.provider_id = "github" + sub.credential_type = CredentialType.OAUTH2 + sub.credentials = {} + sub.parameters = {} + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + 
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=sub, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription" + ), + ): + assert method(api, "s1") == 200 + + def test_delete_subscription(self, app): + api = TriggerSubscriptionDeleteApi() + method = unwrap(api.post) + + mock_session = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch("controllers.console.workspace.trigger_providers.db") as mock_db, + patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls, + patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription" + ), + ): + mock_db.engine = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + result = method(api, "sub1") + + assert result["result"] == "success" + + def test_delete_subscription_value_error(self, app): + api = TriggerSubscriptionDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch("controllers.console.workspace.trigger_providers.db") as mock_db, + patch("controllers.console.workspace.trigger_providers.Session") as session_cls, + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider", + side_effect=ValueError("bad"), + ), + ): + mock_db.engine = MagicMock() + session_cls.return_value.__enter__.return_value = MagicMock() + + with pytest.raises(BadRequest): + method(api, "sub1") + + +class TestTriggerOAuthApis: + def test_oauth_authorize_success(self, app): + api = TriggerOAuthAuthorizeApi() + method = unwrap(api.get) + 
+ with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + return_value=MagicMock(id="b1"), + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context", + return_value="ctx", + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url", + return_value=MagicMock(authorization_url="url"), + ), + ): + resp = method(api, "github") + assert resp.status_code == 200 + + def test_oauth_authorize_no_client(self, app): + api = TriggerOAuthAuthorizeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(NotFoundError): + method(api, "github") + + def test_oauth_callback_forbidden(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "github") + + def test_oauth_callback_success(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + 
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials", + return_value=MagicMock(credentials={"a": 1}, expires_at=1), + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder" + ), + ): + resp = method(api, "github") + assert resp.status_code == 302 + + def test_oauth_callback_no_oauth_client(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", + return_value=ctx, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "github") + + def test_oauth_callback_empty_credentials(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", + return_value=ctx, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials", + return_value=MagicMock(credentials=None, expires_at=None), + ), + ): + with pytest.raises(ValueError): + method(api, "github") + + +class TestTriggerOAuthClientManageApi: + def test_get_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + 
patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params", + return_value={}, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled", + return_value=False, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists", + return_value=True, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider", + return_value=MagicMock(get_oauth_client_schema=lambda: {}), + ), + ): + result = method(api, "github") + assert "configured" in result + + def test_post_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"enabled": True}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "github") == {"ok": True} + + def test_delete_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "github") == {"ok": True} + + def test_oauth_client_post_value_error(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"enabled": True}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", + side_effect=ValueError("bad"), + ), + ): + 
with pytest.raises(BadRequest): + method(api, "github") + + +class TestTriggerSubscriptionVerifyApi: + def test_verify_success(self, app): + api = TriggerSubscriptionVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", + return_value={"ok": True}, + ), + ): + assert method(api, "github", "s1") == {"ok": True} + + @pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")]) + def test_verify_errors(self, app, raised_exception): + api = TriggerSubscriptionVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", + side_effect=raised_exception, + ), + ): + with pytest.raises(BadRequest): + method(api, "github", "s1") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py new file mode 100644 index 0000000000..06f666fa60 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -0,0 +1,605 @@ +from datetime import datetime +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import Unauthorized + +import services +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.error import AccountNotLinkTenantError +from controllers.console.workspace.workspace import ( + 
CustomConfigWorkspaceApi, + SwitchWorkspaceApi, + TenantApi, + TenantListApi, + WebappLogoWorkspaceApi, + WorkspaceInfoApi, + WorkspaceListApi, + WorkspacePermissionApi, +) +from enums.cloud_plan import CloudPlan +from models.account import TenantStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestTenantListApi: + def test_get_success(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + features = MagicMock() + features.billing.enabled = True + features.billing.subscription.plan = CloudPlan.SANDBOX + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features), + ): + result, status = method(api) + + assert status == 200 + assert len(result["workspaces"]) == 2 + assert result["workspaces"][0]["current"] is True + + def test_get_billing_disabled(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant = MagicMock( + id="t1", + name="Tenant", + status="active", + created_at=datetime.utcnow(), + ) + + features = MagicMock() + features.billing.enabled = False + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant], + ), + patch( + 
"controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + + +class TestWorkspaceListApi: + def test_get_success(self, app): + api = WorkspaceListApi() + method = unwrap(api.get) + + tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow()) + + paginate_result = MagicMock( + items=[tenant], + has_next=False, + total=1, + ) + + with ( + app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}), + patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result), + ): + result, status = method(api) + + assert status == 200 + assert result["total"] == 1 + assert result["has_more"] is False + + def test_get_has_next_true(self, app): + api = WorkspaceListApi() + method = unwrap(api.get) + + tenant = MagicMock( + id="t1", + name="T", + status="active", + created_at=datetime.utcnow(), + ) + + paginate_result = MagicMock( + items=[tenant], + has_next=True, + total=10, + ) + + with ( + app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}), + patch( + "controllers.console.workspace.workspace.db.paginate", + return_value=paginate_result, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["has_more"] is True + + +class TestTenantApi: + def test_post_active_tenant(self, app): + api = TenantApi() + method = unwrap(api.post) + + tenant = MagicMock(status="active") + + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} + ), + ): + result, status = method(api) + + assert status == 200 + assert result["id"] == "t1" + + 
def test_post_archived_with_switch(self, app): + api = TenantApi() + method = unwrap(api.post) + + archived = MagicMock(status=TenantStatus.ARCHIVE) + new_tenant = MagicMock(status="active") + + user = MagicMock(current_tenant=archived) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"} + ), + ): + result, status = method(api) + + assert result["id"] == "new" + + def test_post_archived_no_tenant(self, app): + api = TenantApi() + method = unwrap(api.post) + + user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE)) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]), + ): + with pytest.raises(Unauthorized): + method(api) + + def test_post_info_path(self, app): + api = TenantApi() + method = unwrap(api.post) + + tenant = MagicMock(status="active") + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/info"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(user, "t1"), + ), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"id": "t1"}, + ), + patch("controllers.console.workspace.workspace.logger.warning") as warn_mock, + ): + result, status = method(api) + + warn_mock.assert_called_once() + assert status == 200 + + +class TestSwitchWorkspaceApi: + def test_switch_success(self, app): + api = 
SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "t2"} + tenant = MagicMock(id="t2") + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"} + ), + ): + query_mock.return_value.get.return_value = tenant + result = method(api) + + assert result["result"] == "success" + + def test_switch_not_linked(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "bad"} + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception), + ): + with pytest.raises(AccountNotLinkTenantError): + method(api) + + def test_switch_tenant_not_found(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "missing"} + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + ): + query_mock.return_value.get.return_value = None + + with pytest.raises(ValueError): + method(api) + + +class TestCustomConfigWorkspaceApi: + def test_post_success(self, app): + api = CustomConfigWorkspaceApi() + method = unwrap(api.post) + + tenant = 
MagicMock(custom_config_dict={}) + + payload = {"remove_webapp_brand": True} + + with ( + app.test_request_context("/workspaces/custom-config", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_logo_fallback(self, app): + api = CustomConfigWorkspaceApi() + method = unwrap(api.post) + + tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"}) + + payload = {"remove_webapp_brand": False} + + with ( + app.test_request_context("/workspaces/custom-config", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch( + "controllers.console.workspace.workspace.db.get_or_404", + return_value=tenant, + ), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"id": "t1"}, + ), + ): + result = method(api) + + assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo" + assert result["result"] == "success" + + +class TestWebappLogoWorkspaceApi: + def test_no_file(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/upload", data={}), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + ): + with pytest.raises(NoFileUploadedError): + method(api) + + def test_too_many_files(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + data = { + "file": MagicMock(), + 
"extra": MagicMock(), + } + + with ( + app.test_request_context("/upload", data=data), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + ): + with pytest.raises(TooManyFilesError): + method(api) + + def test_invalid_extension(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = MagicMock(filename="test.txt") + + with ( + app.test_request_context("/upload", data={"file": file}), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + ): + with pytest.raises(UnsupportedFileTypeError): + method(api) + + def test_upload_success(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"data"), + filename="logo.png", + content_type="image/png", + ) + + upload = MagicMock(id="file1") + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.return_value = upload + + result, status = method(api) + + assert status == 201 + assert result["id"] == "file1" + + def test_filename_missing(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"data"), + filename="", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + ): + with pytest.raises(FilenameNotExistsError): + method(api) + + def 
test_file_too_large(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"x"), + filename="logo.png", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big") + + with pytest.raises(FileTooLargeError): + method(api) + + def test_service_unsupported_file(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"x"), + filename="logo.png", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError() + + with pytest.raises(UnsupportedFileTypeError): + method(api) + + +class TestWorkspaceInfoApi: + def test_post_success(self, app): + api = WorkspaceInfoApi() + method = unwrap(api.post) + + tenant = MagicMock() + + payload = {"name": "New Name"} + + with ( + app.test_request_context("/workspaces/info", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.db.get_or_404", 
return_value=tenant), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"name": "New Name"}, + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_no_current_tenant(self, app): + api = WorkspaceInfoApi() + method = unwrap(api.post) + + payload = {"name": "X"} + + with ( + app.test_request_context("/workspaces/info", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestWorkspacePermissionApi: + def test_get_success(self, app): + api = WorkspacePermissionApi() + method = unwrap(api.get) + + permission = MagicMock( + workspace_id="t1", + allow_member_invite=True, + allow_owner_transfer=False, + ) + + with ( + app.test_request_context("/permission"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission", + return_value=permission, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["workspace_id"] == "t1" + + def test_no_current_tenant(self, app): + api = WorkspacePermissionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/permission"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(ValueError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py new file mode 100644 index 0000000000..b290748155 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py @@ -0,0 +1,142 @@ +from 
__future__ import annotations + +import importlib +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace import plugin_permission_required +from models.account import TenantPluginPermission + + +class _SessionStub: + def __init__(self, permission): + self._permission = permission + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *_args, **_kwargs): + return self + + def where(self, *_args, **_kwargs): + return self + + def first(self): + return self._permission + + +def _workspace_module(): + return importlib.import_module(plugin_permission_required.__module__) + + +def _patch_session(monkeypatch: pytest.MonkeyPatch, permission): + module = _workspace_module() + monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission)) + monkeypatch.setattr(module, "db", SimpleNamespace(engine=object())) + + +def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, None) + + @plugin_permission_required() + def handler(): + return "ok" + + assert handler() == "ok" + + +def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.NOBODY, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def 
test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + assert handler() == "ok" + + +def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.NOBODY, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(debug_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + permission = SimpleNamespace( + 
install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.ADMINS, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(debug_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() diff --git a/api/tests/unit_tests/controllers/web/__init__.py b/api/tests/unit_tests/controllers/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py new file mode 100644 index 0000000000..274d78c9cf --- /dev/null +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -0,0 +1,85 @@ +"""Shared fixtures for controllers.web unit tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask + + +@pytest.fixture +def app() -> Flask: + """Minimal Flask app for request contexts.""" + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class FakeSession: + """Stand-in for db.session that returns pre-seeded objects by model class name.""" + + def __init__(self, mapping: dict[str, Any] | None = None): + self._mapping: dict[str, Any] = mapping or {} + self._model_name: str | None = None + + def query(self, model: type) -> FakeSession: + self._model_name = model.__name__ + return self + + def where(self, *_args: object, **_kwargs: object) -> FakeSession: + return self + + def first(self) -> Any: + assert self._model_name is not None + return self._mapping.get(self._model_name) + + +class FakeDB: + """Minimal db stub exposing engine and session.""" + + def __init__(self, session: FakeSession | None = None): + self.session = session or FakeSession() + self.engine = object() + + +def make_app_model( + *, + 
app_id: str = "app-1", + tenant_id: str = "tenant-1", + mode: str = "chat", + enable_site: bool = True, + status: str = "normal", +) -> SimpleNamespace: + """Build a fake App model with common defaults.""" + tenant = SimpleNamespace( + id=tenant_id, + status="normal", + plan="basic", + custom_config_dict={}, + ) + return SimpleNamespace( + id=app_id, + tenant_id=tenant_id, + tenant=tenant, + mode=mode, + enable_site=enable_site, + status=status, + workflow=None, + app_model_config=None, + ) + + +def make_end_user( + *, + user_id: str = "end-user-1", + session_id: str = "session-1", + external_user_id: str = "ext-user-1", +) -> SimpleNamespace: + """Build a fake EndUser model with common defaults.""" + return SimpleNamespace( + id=user_id, + session_id=session_id, + external_user_id=external_user_id, + ) diff --git a/api/tests/unit_tests/controllers/web/test_app.py b/api/tests/unit_tests/controllers/web/test_app.py new file mode 100644 index 0000000000..ce7ae27188 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_app.py @@ -0,0 +1,165 @@ +"""Unit tests for controllers.web.app endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission +from controllers.web.error import AppUnavailableError + + +# --------------------------------------------------------------------------- +# AppParameterApi +# --------------------------------------------------------------------------- +class TestAppParameterApi: + def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None: + features_dict = {"opening_statement": "Hello"} + workflow = SimpleNamespace( + features_dict=features_dict, + user_input_form=lambda to_old_structure=False: [], + ) + app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow) + + with ( + 
app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"} + result = AppParameterApi().get(app_model, SimpleNamespace()) + + mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[]) + assert result == {"result": "ok"} + + def test_workflow_mode_uses_workflow(self, app: Flask) -> None: + features_dict = {} + workflow = SimpleNamespace( + features_dict=features_dict, + user_input_form=lambda to_old_structure=False: [{"var": "x"}], + ) + app_model = SimpleNamespace(mode="workflow", workflow=workflow) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {} + AppParameterApi().get(app_model, SimpleNamespace()) + + mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}]) + + def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None: + app_model = SimpleNamespace(mode="advanced-chat", workflow=None) + with app.test_request_context("/parameters"): + with pytest.raises(AppUnavailableError): + AppParameterApi().get(app_model, SimpleNamespace()) + + def test_standard_mode_uses_app_model_config(self, app: Flask) -> None: + config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"}) + app_model = SimpleNamespace(mode="chat", app_model_config=config) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + 
mock_fields.model_validate.return_value.model_dump.return_value = {} + AppParameterApi().get(app_model, SimpleNamespace()) + + call_kwargs = mock_params.call_args + assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}] + + def test_standard_mode_no_config_raises(self, app: Flask) -> None: + app_model = SimpleNamespace(mode="chat", app_model_config=None) + with app.test_request_context("/parameters"): + with pytest.raises(AppUnavailableError): + AppParameterApi().get(app_model, SimpleNamespace()) + + +# --------------------------------------------------------------------------- +# AppMeta +# --------------------------------------------------------------------------- +class TestAppMeta: + @patch("controllers.web.app.AppService") + def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None: + mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}} + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context("/meta"): + result = AppMeta().get(app_model, SimpleNamespace()) + + assert result == {"tool_icons": {}} + + +# --------------------------------------------------------------------------- +# AppAccessMode +# --------------------------------------------------------------------------- +class TestAppAccessMode: + @patch("controllers.web.app.FeatureService.get_system_features") + def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + with app.test_request_context("/webapp/access-mode?appId=app-1"): + result = AppAccessMode().get() + + assert result == {"accessMode": "public"} + + @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.app.FeatureService.get_system_features") + def test_returns_access_mode_with_app_id( + self, mock_features: MagicMock, mock_access: MagicMock, app: Flask + ) -> None: + 
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_access.return_value = SimpleNamespace(access_mode="internal") + + with app.test_request_context("/webapp/access-mode?appId=app-1"): + result = AppAccessMode().get() + + assert result == {"accessMode": "internal"} + mock_access.assert_called_once_with("app-1") + + @patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id") + @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.app.FeatureService.get_system_features") + def test_resolves_app_code_to_id( + self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_access.return_value = SimpleNamespace(access_mode="external") + + with app.test_request_context("/webapp/access-mode?appCode=code1"): + result = AppAccessMode().get() + + mock_resolve.assert_called_once_with("code1") + mock_access.assert_called_once_with("resolved-id") + assert result == {"accessMode": "external"} + + @patch("controllers.web.app.FeatureService.get_system_features") + def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + + with app.test_request_context("/webapp/access-mode"): + with pytest.raises(ValueError, match="appId or appCode"): + AppAccessMode().get() + + +# --------------------------------------------------------------------------- +# AppWebAuthPermission +# --------------------------------------------------------------------------- +class TestAppWebAuthPermission: + @patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None: + with 
app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}): + result = AppWebAuthPermission().get() + + assert result == {"result": True} + + def test_raises_when_missing_app_id(self, app: Flask) -> None: + with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}): + with pytest.raises(ValueError, match="appId"): + AppWebAuthPermission().get() diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py new file mode 100644 index 0000000000..01f34345aa --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -0,0 +1,135 @@ +"""Unit tests for controllers.web.audio endpoints.""" + +from __future__ import annotations + +from io import BytesIO +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.audio import AudioApi, TextApi +from controllers.web.error import ( + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1", external_user_id="ext-1") + + +# --------------------------------------------------------------------------- +# AudioApi (audio-to-text) +# --------------------------------------------------------------------------- +class TestAudioApi: + 
@patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"}) + def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + data = {"file": (BytesIO(b"fake-audio"), "test.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + result = AudioApi().post(_app_model(), _end_user()) + + assert result == {"text": "hello"} + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError()) + def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b""), "empty.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(NoAudioUploadedError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big")) + def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"big"), "big.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(AudioTooLargeError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError()) + def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"bad"), "bad.xyz")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(UnsupportedAudioTypeError): + AudioApi().post(_app_model(), _end_user()) + + @patch( + "controllers.web.audio.AudioService.transcript_asr", + side_effect=ProviderNotSupportSpeechToTextServiceError(), + ) + def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None: + 
data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderNotSupportSpeechToTextError): + AudioApi().post(_app_model(), _end_user()) + + @patch( + "controllers.web.audio.AudioService.transcript_asr", + side_effect=ProviderTokenNotInitError(description="no token"), + ) + def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderNotInitializeError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError()) + def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderQuotaExceededError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError()) + def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + AudioApi().post(_app_model(), _end_user()) + + +# --------------------------------------------------------------------------- +# TextApi (text-to-audio) +# --------------------------------------------------------------------------- +class TestTextApi: + @patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes") + @patch("controllers.web.audio.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_tts: 
MagicMock, app: Flask) -> None: + mock_ns.payload = {"text": "hello", "voice": "alloy"} + + with app.test_request_context("/text-to-audio", method="POST"): + result = TextApi().post(_app_model(), _end_user()) + + assert result == "audio-bytes" + mock_tts.assert_called_once() + + @patch( + "controllers.web.audio.AudioService.transcript_tts", + side_effect=InvokeError(description="invoke failed"), + ) + @patch("controllers.web.audio.web_ns") + def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None: + mock_ns.payload = {"text": "hello"} + + with app.test_request_context("/text-to-audio", method="POST"): + with pytest.raises(CompletionRequestError): + TextApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py new file mode 100644 index 0000000000..e88bcf2ae6 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -0,0 +1,161 @@ +"""Unit tests for controllers.web.completion endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from controllers.web.error import ( + CompletionRequestError, + NotChatAppError, + NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# 
--------------------------------------------------------------------------- +# CompletionApi +# --------------------------------------------------------------------------- +class TestCompletionApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(NotCompletionAppError): + CompletionApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"}) + @patch("controllers.web.completion.AppGenerateService.generate") + @patch("controllers.web.completion.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "test"} + mock_gen.return_value = "response-obj" + + with app.test_request_context("/completion-messages", method="POST"): + result = CompletionApi().post(_completion_app(), _end_user()) + + assert result == {"answer": "hi"} + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=ProviderTokenNotInitError(description="not init"), + ) + @patch("controllers.web.completion.web_ns") + def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderNotInitializeError): + CompletionApi().post(_completion_app(), _end_user()) + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=QuotaExceededError(), + ) + @patch("controllers.web.completion.web_ns") + def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderQuotaExceededError): + CompletionApi().post(_completion_app(), _end_user()) + 
+ @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=ModelCurrentlyNotSupportError(), + ) + @patch("controllers.web.completion.web_ns") + def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + CompletionApi().post(_completion_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# CompletionStopApi +# --------------------------------------------------------------------------- +class TestCompletionStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/completion-messages/task-1/stop", method="POST"): + with pytest.raises(NotCompletionAppError): + CompletionStopApi().post(_chat_app(), _end_user(), "task-1") + + @patch("controllers.web.completion.AppTaskService.stop_task") + def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None: + with app.test_request_context("/completion-messages/task-1/stop", method="POST"): + result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1") + + assert status == 200 + assert result == {"result": "success"} + + +# --------------------------------------------------------------------------- +# ChatApi +# --------------------------------------------------------------------------- +class TestChatApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/chat-messages", method="POST"): + with pytest.raises(NotChatAppError): + ChatApi().post(_completion_app(), _end_user()) + + @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"}) + @patch("controllers.web.completion.AppGenerateService.generate") + @patch("controllers.web.completion.web_ns") + def test_happy_path(self, mock_ns: 
MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "hi"} + mock_gen.return_value = "response" + + with app.test_request_context("/chat-messages", method="POST"): + result = ChatApi().post(_chat_app(), _end_user()) + + assert result == {"answer": "reply"} + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=InvokeError(description="rate limit"), + ) + @patch("controllers.web.completion.web_ns") + def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "x"} + + with app.test_request_context("/chat-messages", method="POST"): + with pytest.raises(CompletionRequestError): + ChatApi().post(_chat_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# ChatStopApi +# --------------------------------------------------------------------------- +class TestChatStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/chat-messages/task-1/stop", method="POST"): + with pytest.raises(NotChatAppError): + ChatStopApi().post(_completion_app(), _end_user(), "task-1") + + @patch("controllers.web.completion.AppTaskService.stop_task") + def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None: + with app.test_request_context("/chat-messages/task-1/stop", method="POST"): + result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1") + + assert status == 200 + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/web/test_conversation.py b/api/tests/unit_tests/controllers/web/test_conversation.py new file mode 100644 index 0000000000..e5adbbbf66 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_conversation.py @@ -0,0 +1,183 @@ +"""Unit tests for controllers.web.conversation endpoints.""" + +from __future__ import annotations + +from types import 
SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from controllers.web.error import NotChatAppError +from services.errors.conversation import ConversationNotExistsError + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# ConversationListApi +# --------------------------------------------------------------------------- +class TestConversationListApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/conversations"): + with pytest.raises(NotChatAppError): + ConversationListApi().get(_completion_app(), _end_user()) + + @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id") + @patch("controllers.web.conversation.db") + def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None: + conv_id = str(uuid4()) + conv = SimpleNamespace( + id=conv_id, + name="Test", + inputs={}, + status="normal", + introduction="", + created_at=1700000000, + updated_at=1700000000, + ) + mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv]) + mock_db.engine = "engine" + + session_mock = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + + with ( + app.test_request_context("/conversations?limit=20"), + patch("controllers.web.conversation.Session", return_value=session_ctx), + ): + 
result = ConversationListApi().get(_chat_app(), _end_user()) + + assert result["limit"] == 20 + assert result["has_more"] is False + + +# --------------------------------------------------------------------------- +# ConversationApi (delete) +# --------------------------------------------------------------------------- +class TestConversationApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}"): + with pytest.raises(NotChatAppError): + ConversationApi().delete(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.ConversationService.delete") + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}"): + result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) + + assert status == 204 + assert result["result"] == "success" + + @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) + def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}"): + with pytest.raises(NotFound, match="Conversation Not Exists"): + ConversationApi().delete(_chat_app(), _end_user(), c_id) + + +# --------------------------------------------------------------------------- +# ConversationRenameApi +# --------------------------------------------------------------------------- +class TestConversationRenameApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}): + with pytest.raises(NotChatAppError): + ConversationRenameApi().post(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.ConversationService.rename") + @patch("controllers.web.conversation.web_ns") + def test_rename_success(self, mock_ns: MagicMock, 
mock_rename: MagicMock, app: Flask) -> None: + c_id = uuid4() + mock_ns.payload = {"name": "New Name", "auto_generate": False} + conv = SimpleNamespace( + id=str(c_id), + name="New Name", + inputs={}, + status="normal", + introduction="", + created_at=1700000000, + updated_at=1700000000, + ) + mock_rename.return_value = conv + + with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}): + result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id) + + assert result["name"] == "New Name" + + @patch( + "controllers.web.conversation.ConversationService.rename", + side_effect=ConversationNotExistsError(), + ) + @patch("controllers.web.conversation.web_ns") + def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: + c_id = uuid4() + mock_ns.payload = {"name": "X", "auto_generate": False} + + with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}): + with pytest.raises(NotFound, match="Conversation Not Exists"): + ConversationRenameApi().post(_chat_app(), _end_user(), c_id) + + +# --------------------------------------------------------------------------- +# ConversationPinApi / ConversationUnPinApi +# --------------------------------------------------------------------------- +class TestConversationPinApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"): + with pytest.raises(NotChatAppError): + ConversationPinApi().patch(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.WebConversationService.pin") + def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): + result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id) + + assert result["result"] == "success" + + 
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError()) + def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): + with pytest.raises(NotFound): + ConversationPinApi().patch(_chat_app(), _end_user(), c_id) + + +class TestConversationUnPinApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"): + with pytest.raises(NotChatAppError): + ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.WebConversationService.unpin") + def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"): + result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id) + + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/web/test_error.py b/api/tests/unit_tests/controllers/web/test_error.py new file mode 100644 index 0000000000..0387d002ba --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_error.py @@ -0,0 +1,75 @@ +"""Unit tests for controllers.web.error HTTP exception classes.""" + +from __future__ import annotations + +import pytest + +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + AppSuggestedQuestionsAfterAnswerDisabledError, + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + ConversationCompletedError, + InvalidArgumentError, + InvokeRateLimitError, + NoAudioUploadedError, + NotChatAppError, + NotCompletionAppError, + NotFoundError, + NotWorkflowAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, + WebAppAuthAccessDeniedError, + 
WebAppAuthRequiredError, + WebFormRateLimitExceededError, +) + +_ERROR_SPECS: list[tuple[type, str, int]] = [ + (AppUnavailableError, "app_unavailable", 400), + (NotCompletionAppError, "not_completion_app", 400), + (NotChatAppError, "not_chat_app", 400), + (NotWorkflowAppError, "not_workflow_app", 400), + (ConversationCompletedError, "conversation_completed", 400), + (ProviderNotInitializeError, "provider_not_initialize", 400), + (ProviderQuotaExceededError, "provider_quota_exceeded", 400), + (ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400), + (CompletionRequestError, "completion_request_error", 400), + (AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403), + (AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403), + (NoAudioUploadedError, "no_audio_uploaded", 400), + (AudioTooLargeError, "audio_too_large", 413), + (UnsupportedAudioTypeError, "unsupported_audio_type", 415), + (ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400), + (WebAppAuthRequiredError, "web_sso_auth_required", 401), + (WebAppAuthAccessDeniedError, "web_app_access_denied", 401), + (InvokeRateLimitError, "rate_limit_error", 429), + (WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429), + (NotFoundError, "not_found", 404), + (InvalidArgumentError, "invalid_param", 400), +] + + +@pytest.mark.parametrize( + ("cls", "expected_code", "expected_status"), + _ERROR_SPECS, + ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS], +) +def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None: + """Each error class exposes the correct error_code and HTTP status code.""" + assert cls.error_code == expected_code + assert cls.code == expected_status + + +def test_error_classes_have_description() -> None: + """Every error class has a description (string or None for generic errors).""" + # NotFoundError and InvalidArgumentError use None description by 
design + _NO_DESCRIPTION = {NotFoundError, InvalidArgumentError} + for cls, _, _ in _ERROR_SPECS: + if cls in _NO_DESCRIPTION: + continue + assert isinstance(cls.description, str), f"{cls.__name__} missing description" + assert len(cls.description) > 0, f"{cls.__name__} has empty description" diff --git a/api/tests/unit_tests/controllers/web/test_feature.py b/api/tests/unit_tests/controllers/web/test_feature.py new file mode 100644 index 0000000000..fe45d5f059 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_feature.py @@ -0,0 +1,38 @@ +"""Unit tests for controllers.web.feature endpoints.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from flask import Flask + +from controllers.web.feature import SystemFeatureApi + + +class TestSystemFeatureApi: + @patch("controllers.web.feature.FeatureService.get_system_features") + def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None: + mock_model = MagicMock() + mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}} + mock_features.return_value = mock_model + + with app.test_request_context("/system-features"): + result = SystemFeatureApi().get() + + assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}} + mock_features.assert_called_once() + + @patch("controllers.web.feature.FeatureService.get_system_features") + def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None: + """SystemFeatureApi is unauthenticated by design — no WebApiResource decorator.""" + mock_model = MagicMock() + mock_model.model_dump.return_value = {} + mock_features.return_value = mock_model + + # Verify it's a bare Resource, not WebApiResource + from flask_restx import Resource + + from controllers.web.wraps import WebApiResource + + assert issubclass(SystemFeatureApi, Resource) + assert not issubclass(SystemFeatureApi, WebApiResource) diff --git 
a/api/tests/unit_tests/controllers/web/test_files.py b/api/tests/unit_tests/controllers/web/test_files.py new file mode 100644 index 0000000000..a3921b0373 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_files.py @@ -0,0 +1,89 @@ +"""Unit tests for controllers.web.files endpoints.""" + +from __future__ import annotations + +from io import BytesIO +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, +) +from controllers.web.files import FileApi + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +class TestFileApi: + def test_no_file_uploaded(self, app: Flask) -> None: + with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"): + with pytest.raises(NoFileUploadedError): + FileApi().post(_app_model(), _end_user()) + + def test_too_many_files(self, app: Flask) -> None: + data = { + "file": (BytesIO(b"a"), "a.txt"), + "file2": (BytesIO(b"b"), "b.txt"), + } + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + # Now has "file" key but len(request.files) > 1 + with pytest.raises(TooManyFilesError): + FileApi().post(_app_model(), _end_user()) + + def test_filename_missing(self, app: Flask) -> None: + data = {"file": (BytesIO(b"content"), "")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(FilenameNotExistsError): + FileApi().post(_app_model(), _end_user()) + + @patch("controllers.web.files.FileService") + @patch("controllers.web.files.db") + def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + 
from datetime import datetime + + upload_file = SimpleNamespace( + id="file-1", + name="test.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by="eu-1", + created_at=datetime(2024, 1, 1), + ) + mock_file_svc_cls.return_value.upload_file.return_value = upload_file + + data = {"file": (BytesIO(b"content"), "test.txt")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + result, status = FileApi().post(_app_model(), _end_user()) + + assert status == 201 + assert result["id"] == "file-1" + assert result["name"] == "test.txt" + + @patch("controllers.web.files.FileService") + @patch("controllers.web.files.db") + def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None: + import services.errors.file + + mock_db.engine = "engine" + mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError( + description="max 10MB" + ) + + data = {"file": (BytesIO(b"big"), "big.txt")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(FileTooLargeError): + FileApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_message_endpoints.py b/api/tests/unit_tests/controllers/web/test_message_endpoints.py new file mode 100644 index 0000000000..89ab93d8d4 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_message_endpoints.py @@ -0,0 +1,150 @@ +"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + NotChatAppError, + NotCompletionAppError, +) +from 
controllers.web.message import ( + MessageFeedbackApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# MessageFeedbackApi +# --------------------------------------------------------------------------- +class TestMessageFeedbackApi: + @patch("controllers.web.message.MessageService.create_feedback") + @patch("controllers.web.message.web_ns") + def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": "like", "content": "great"} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + assert result == {"result": "success"} + mock_create.assert_called_once() + + @patch("controllers.web.message.MessageService.create_feedback") + @patch("controllers.web.message.web_ns") + def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": None} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + assert result == {"result": "success"} + + @patch( + "controllers.web.message.MessageService.create_feedback", + side_effect=MessageNotExistsError(), + ) + @patch("controllers.web.message.web_ns") + def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = 
{"rating": "dislike"} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + with pytest.raises(NotFound, match="Message Not Exists"): + MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + +# --------------------------------------------------------------------------- +# MessageMoreLikeThisApi +# --------------------------------------------------------------------------- +class TestMessageMoreLikeThisApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(NotCompletionAppError): + MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id) + + @patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"}) + @patch("controllers.web.message.AppGenerateService.generate_more_like_this") + def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + msg_id = uuid4() + mock_gen.return_value = "response" + + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + assert result == {"answer": "similar"} + + @patch( + "controllers.web.message.AppGenerateService.generate_more_like_this", + side_effect=MessageNotExistsError(), + ) + def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(NotFound, match="Message Not Exists"): + MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + @patch( + "controllers.web.message.AppGenerateService.generate_more_like_this", + side_effect=MoreLikeThisDisabledError(), + ) + def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with 
app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(AppMoreLikeThisDisabledError): + MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + +# --------------------------------------------------------------------------- +# MessageSuggestedQuestionApi +# --------------------------------------------------------------------------- +class TestMessageSuggestedQuestionApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotChatAppError): + MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) + + @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer") + def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None: + msg_id = uuid4() + mock_suggest.return_value = ["What about X?", "Tell me more about Y."] + + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id) + + assert result["data"] == ["What about X?", "Tell me more about Y."] + + @patch( + "controllers.web.message.MessageService.get_suggested_questions_after_answer", + side_effect=MessageNotExistsError(), + ) + def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotFound, match="Message not found"): + MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id) diff --git a/api/tests/unit_tests/controllers/web/test_passport.py 
b/api/tests/unit_tests/controllers/web/test_passport.py new file mode 100644 index 0000000000..58d58626b2 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_passport.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web.error import WebAppAuthRequiredError +from controllers.web.passport import ( + PassportService, + decode_enterprise_webapp_user_id, + exchange_token_for_existing_web_user, + generate_session_id, +) +from services.webapp_auth_service import WebAppAuthType + + +def test_decode_enterprise_webapp_user_id_none() -> None: + assert decode_enterprise_webapp_user_id(None) is None + + +def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"}) + with pytest.raises(Unauthorized): + decode_enterprise_webapp_user_id("token") + + +def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None: + decoded = {"token_source": "webapp_login_token", "user_id": "u1"} + monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded) + assert decode_enterprise_webapp_user_id("token") == decoded + + +def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + return site if _scalar_side_effect.calls == 1 else app_model + + db_session = SimpleNamespace(scalar=_scalar_side_effect) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + 
monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp") + + decoded = {"auth_type": "public"} + result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC) + assert result == "resp" + + +def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + return site if _scalar_side_effect.calls == 1 else app_model + + db_session = SimpleNamespace(scalar=_scalar_side_effect) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + decoded = {"auth_type": "internal"} + with pytest.raises(WebAppAuthRequiredError): + exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL) + + +def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1") + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + if _scalar_side_effect.calls == 1: + return site + if _scalar_side_effect.calls == 2: + return app_model + return None + + db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + decoded = {"auth_type": "internal"} + with pytest.raises(NotFound): + exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL) + + +def test_generate_session_id(monkeypatch: 
pytest.MonkeyPatch) -> None: + counts = [1, 0] + + def _scalar(*_args, **_kwargs): + return counts.pop(0) + + db_session = SimpleNamespace(scalar=_scalar) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + session_id = generate_session_id() + assert session_id diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py new file mode 100644 index 0000000000..dcf8133712 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -0,0 +1,423 @@ +"""Unit tests for Pydantic models defined in controllers.web modules. + +Covers validation logic, field defaults, constraints, and custom validators +for all ~15 Pydantic models across the web controller layer. +""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +# --------------------------------------------------------------------------- +# app.py models +# --------------------------------------------------------------------------- +from controllers.web.app import AppAccessModeQuery + + +class TestAppAccessModeQuery: + def test_alias_resolution(self) -> None: + q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"}) + assert q.app_id == "abc" + assert q.app_code == "xyz" + + def test_defaults_to_none(self) -> None: + q = AppAccessModeQuery.model_validate({}) + assert q.app_id is None + assert q.app_code is None + + def test_accepts_snake_case(self) -> None: + q = AppAccessModeQuery(app_id="id1", app_code="code1") + assert q.app_id == "id1" + assert q.app_code == "code1" + + +# --------------------------------------------------------------------------- +# audio.py models +# --------------------------------------------------------------------------- +from controllers.web.audio import TextToAudioPayload + + +class TestTextToAudioPayload: + def test_defaults(self) -> None: + p = 
TextToAudioPayload.model_validate({}) + assert p.message_id is None + assert p.voice is None + assert p.text is None + assert p.streaming is None + + def test_valid_uuid_message_id(self) -> None: + uid = str(uuid4()) + p = TextToAudioPayload(message_id=uid) + assert p.message_id == uid + + def test_none_message_id_passthrough(self) -> None: + p = TextToAudioPayload(message_id=None) + assert p.message_id is None + + def test_invalid_uuid_message_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + TextToAudioPayload(message_id="not-a-uuid") + + +# --------------------------------------------------------------------------- +# completion.py models +# --------------------------------------------------------------------------- +from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload + + +class TestCompletionMessagePayload: + def test_defaults(self) -> None: + p = CompletionMessagePayload(inputs={}) + assert p.query == "" + assert p.files is None + assert p.response_mode is None + assert p.retriever_from == "web_app" + + def test_accepts_full_payload(self) -> None: + p = CompletionMessagePayload( + inputs={"key": "val"}, + query="test", + files=[{"id": "f1"}], + response_mode="streaming", + ) + assert p.response_mode == "streaming" + assert p.files == [{"id": "f1"}] + + def test_invalid_response_mode(self) -> None: + with pytest.raises(ValidationError): + CompletionMessagePayload(inputs={}, response_mode="invalid") + + +class TestChatMessagePayload: + def test_valid_uuid_fields(self) -> None: + cid = str(uuid4()) + pid = str(uuid4()) + p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid) + assert p.conversation_id == cid + assert p.parent_message_id == pid + + def test_none_uuid_fields(self) -> None: + p = ChatMessagePayload(inputs={}, query="hi") + assert p.conversation_id is None + assert p.parent_message_id is None + + def test_invalid_conversation_id(self) -> None: + with 
pytest.raises(ValidationError, match="not a valid uuid"): + ChatMessagePayload(inputs={}, query="hi", conversation_id="bad") + + def test_invalid_parent_message_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad") + + def test_query_required(self) -> None: + with pytest.raises(ValidationError): + ChatMessagePayload(inputs={}) + + +# --------------------------------------------------------------------------- +# conversation.py models +# --------------------------------------------------------------------------- +from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload + + +class TestConversationListQuery: + def test_defaults(self) -> None: + q = ConversationListQuery() + assert q.last_id is None + assert q.limit == 20 + assert q.pinned is None + assert q.sort_by == "-updated_at" + + def test_limit_lower_bound(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(limit=0) + + def test_limit_upper_bound(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(limit=101) + + def test_limit_boundaries_valid(self) -> None: + assert ConversationListQuery(limit=1).limit == 1 + assert ConversationListQuery(limit=100).limit == 100 + + def test_valid_sort_by_options(self) -> None: + for opt in ("created_at", "-created_at", "updated_at", "-updated_at"): + assert ConversationListQuery(sort_by=opt).sort_by == opt + + def test_invalid_sort_by(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(sort_by="invalid") + + def test_valid_last_id(self) -> None: + uid = str(uuid4()) + assert ConversationListQuery(last_id=uid).last_id == uid + + def test_invalid_last_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ConversationListQuery(last_id="not-uuid") + + +class TestConversationRenamePayload: + def test_auto_generate_true_no_name_required(self) -> 
None: + p = ConversationRenamePayload(auto_generate=True) + assert p.name is None + + def test_auto_generate_false_requires_name(self) -> None: + with pytest.raises(ValidationError, match="name is required"): + ConversationRenamePayload(auto_generate=False) + + def test_auto_generate_false_blank_name_rejected(self) -> None: + with pytest.raises(ValidationError, match="name is required"): + ConversationRenamePayload(auto_generate=False, name=" ") + + def test_auto_generate_false_with_valid_name(self) -> None: + p = ConversationRenamePayload(auto_generate=False, name="My Chat") + assert p.name == "My Chat" + + def test_defaults(self) -> None: + p = ConversationRenamePayload(name="test") + assert p.auto_generate is False + assert p.name == "test" + + +# --------------------------------------------------------------------------- +# message.py models +# --------------------------------------------------------------------------- +from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery + + +class TestMessageListQuery: + def test_valid_query(self) -> None: + cid = str(uuid4()) + q = MessageListQuery(conversation_id=cid) + assert q.conversation_id == cid + assert q.first_id is None + assert q.limit == 20 + + def test_invalid_conversation_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + MessageListQuery(conversation_id="bad") + + def test_limit_bounds(self) -> None: + cid = str(uuid4()) + with pytest.raises(ValidationError): + MessageListQuery(conversation_id=cid, limit=0) + with pytest.raises(ValidationError): + MessageListQuery(conversation_id=cid, limit=101) + + def test_valid_first_id(self) -> None: + cid = str(uuid4()) + fid = str(uuid4()) + q = MessageListQuery(conversation_id=cid, first_id=fid) + assert q.first_id == fid + + def test_invalid_first_id(self) -> None: + cid = str(uuid4()) + with pytest.raises(ValidationError, match="not a valid uuid"): + 
MessageListQuery(conversation_id=cid, first_id="invalid") + + +class TestMessageFeedbackPayload: + def test_defaults(self) -> None: + p = MessageFeedbackPayload() + assert p.rating is None + assert p.content is None + + def test_valid_ratings(self) -> None: + assert MessageFeedbackPayload(rating="like").rating == "like" + assert MessageFeedbackPayload(rating="dislike").rating == "dislike" + + def test_invalid_rating(self) -> None: + with pytest.raises(ValidationError): + MessageFeedbackPayload(rating="neutral") + + +class TestMessageMoreLikeThisQuery: + def test_valid_modes(self) -> None: + assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking" + assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming" + + def test_invalid_mode(self) -> None: + with pytest.raises(ValidationError): + MessageMoreLikeThisQuery(response_mode="invalid") + + def test_required(self) -> None: + with pytest.raises(ValidationError): + MessageMoreLikeThisQuery() + + +# --------------------------------------------------------------------------- +# remote_files.py models +# --------------------------------------------------------------------------- +from controllers.web.remote_files import RemoteFileUploadPayload + + +class TestRemoteFileUploadPayload: + def test_valid_url(self) -> None: + p = RemoteFileUploadPayload(url="https://example.com/file.pdf") + assert str(p.url) == "https://example.com/file.pdf" + + def test_invalid_url(self) -> None: + with pytest.raises(ValidationError): + RemoteFileUploadPayload(url="not-a-url") + + def test_url_required(self) -> None: + with pytest.raises(ValidationError): + RemoteFileUploadPayload() + + +# --------------------------------------------------------------------------- +# saved_message.py models +# --------------------------------------------------------------------------- +from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery + + +class 
TestSavedMessageListQuery: + def test_defaults(self) -> None: + q = SavedMessageListQuery() + assert q.last_id is None + assert q.limit == 20 + + def test_limit_bounds(self) -> None: + with pytest.raises(ValidationError): + SavedMessageListQuery(limit=0) + with pytest.raises(ValidationError): + SavedMessageListQuery(limit=101) + + def test_valid_last_id(self) -> None: + uid = str(uuid4()) + q = SavedMessageListQuery(last_id=uid) + assert q.last_id == uid + + def test_empty_last_id(self) -> None: + q = SavedMessageListQuery(last_id="") + assert q.last_id == "" + + +class TestSavedMessageCreatePayload: + def test_valid_message_id(self) -> None: + uid = str(uuid4()) + p = SavedMessageCreatePayload(message_id=uid) + assert p.message_id == uid + + def test_required(self) -> None: + with pytest.raises(ValidationError): + SavedMessageCreatePayload() + + +# --------------------------------------------------------------------------- +# workflow.py models +# --------------------------------------------------------------------------- +from controllers.web.workflow import WorkflowRunPayload + + +class TestWorkflowRunPayload: + def test_defaults(self) -> None: + p = WorkflowRunPayload(inputs={}) + assert p.inputs == {} + assert p.files is None + + def test_with_files(self) -> None: + p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}]) + assert p.files == [{"id": "f1"}] + + def test_inputs_required(self) -> None: + with pytest.raises(ValidationError): + WorkflowRunPayload() + + +# --------------------------------------------------------------------------- +# forgot_password.py models +# --------------------------------------------------------------------------- +from controllers.web.forgot_password import ( + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordSendPayload, +) + + +class TestForgotPasswordSendPayload: + def test_valid_email(self) -> None: + p = ForgotPasswordSendPayload(email="user@example.com") + assert p.email == 
"user@example.com" + + def test_invalid_email(self) -> None: + with pytest.raises(ValidationError, match="not a valid email"): + ForgotPasswordSendPayload(email="not-an-email") + + def test_language_optional(self) -> None: + p = ForgotPasswordSendPayload(email="a@b.com") + assert p.language is None + + +class TestForgotPasswordCheckPayload: + def test_valid(self) -> None: + p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok") + assert p.email == "a@b.com" + assert p.code == "1234" + assert p.token == "tok" + + def test_empty_token_rejected(self) -> None: + with pytest.raises(ValidationError): + ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="") + + +class TestForgotPasswordResetPayload: + def test_valid_passwords(self) -> None: + p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234") + assert p.new_password == "Valid1234" + + def test_weak_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short") + + def test_letters_only_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi") + + def test_digits_only_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789") + + +# --------------------------------------------------------------------------- +# login.py models +# --------------------------------------------------------------------------- +from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload + + +class TestLoginPayload: + def test_valid(self) -> None: + p = LoginPayload(email="a@b.com", password="Valid1234") + assert 
p.email == "a@b.com" + + def test_invalid_email(self) -> None: + with pytest.raises(ValidationError, match="not a valid email"): + LoginPayload(email="bad", password="Valid1234") + + def test_weak_password(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + LoginPayload(email="a@b.com", password="weak") + + +class TestEmailCodeLoginSendPayload: + def test_valid(self) -> None: + p = EmailCodeLoginSendPayload(email="a@b.com") + assert p.language is None + + def test_with_language(self) -> None: + p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans") + assert p.language == "zh-Hans" + + +class TestEmailCodeLoginVerifyPayload: + def test_valid(self) -> None: + p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok") + assert p.code == "1234" + + def test_empty_token_rejected(self) -> None: + with pytest.raises(ValidationError): + EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="") diff --git a/api/tests/unit_tests/controllers/web/test_remote_files.py b/api/tests/unit_tests/controllers/web/test_remote_files.py new file mode 100644 index 0000000000..8554f440b7 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_remote_files.py @@ -0,0 +1,147 @@ +"""Unit tests for controllers.web.remote_files endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.common.errors import FileTooLargeError, RemoteFileUploadError +from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# RemoteFileInfoApi +# --------------------------------------------------------------------------- +class 
TestRemoteFileInfoApi: + @patch("controllers.web.remote_files.ssrf_proxy") + def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"} + mock_proxy.head.return_value = mock_resp + + with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf") + + assert result["file_type"] == "application/pdf" + assert result["file_length"] == 1024 + + @patch("controllers.web.remote_files.ssrf_proxy") + def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None: + head_resp = MagicMock() + head_resp.status_code = 405 # Method not allowed + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"} + get_resp.raise_for_status = MagicMock() + mock_proxy.head.return_value = head_resp + mock_proxy.get.return_value = get_resp + + with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt") + + assert result["file_type"] == "text/plain" + mock_proxy.get.assert_called_once() + + +# --------------------------------------------------------------------------- +# RemoteFileUploadApi +# --------------------------------------------------------------------------- +class TestRemoteFileUploadApi: + @patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url") + @patch("controllers.web.remote_files.FileService") + @patch("controllers.web.remote_files.helpers.guess_file_info_from_response") + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + @patch("controllers.web.remote_files.db") + def test_upload_success( + self, + 
mock_db: MagicMock, + mock_ns: MagicMock, + mock_proxy: MagicMock, + mock_guess: MagicMock, + mock_file_svc_cls: MagicMock, + mock_signed: MagicMock, + app: Flask, + ) -> None: + mock_db.engine = "engine" + mock_ns.payload = {"url": "https://example.com/file.pdf"} + head_resp = MagicMock() + head_resp.status_code = 200 + head_resp.content = b"pdf-content" + head_resp.request.method = "HEAD" + mock_proxy.head.return_value = head_resp + get_resp = MagicMock() + get_resp.content = b"pdf-content" + mock_proxy.get.return_value = get_resp + + mock_guess.return_value = SimpleNamespace( + filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100 + ) + mock_file_svc_cls.is_file_size_within_limit.return_value = True + + from datetime import datetime + + upload_file = SimpleNamespace( + id="f-1", + name="file.pdf", + size=100, + extension="pdf", + mime_type="application/pdf", + created_by="eu-1", + created_at=datetime(2024, 1, 1), + ) + mock_file_svc_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context("/remote-files/upload", method="POST"): + result, status = RemoteFileUploadApi().post(_app_model(), _end_user()) + + assert status == 201 + assert result["id"] == "f-1" + + @patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False) + @patch("controllers.web.remote_files.helpers.guess_file_info_from_response") + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + def test_file_too_large( + self, + mock_ns: MagicMock, + mock_proxy: MagicMock, + mock_guess: MagicMock, + mock_size_check: MagicMock, + app: Flask, + ) -> None: + mock_ns.payload = {"url": "https://example.com/big.zip"} + head_resp = MagicMock() + head_resp.status_code = 200 + mock_proxy.head.return_value = head_resp + mock_guess.return_value = SimpleNamespace( + filename="big.zip", extension="zip", mimetype="application/zip", size=999999999 + ) + + with 
app.test_request_context("/remote-files/upload", method="POST"): + with pytest.raises(FileTooLargeError): + RemoteFileUploadApi().post(_app_model(), _end_user()) + + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None: + import httpx + + mock_ns.payload = {"url": "https://example.com/bad"} + mock_proxy.head.side_effect = httpx.RequestError("connection failed") + + with app.test_request_context("/remote-files/upload", method="POST"): + with pytest.raises(RemoteFileUploadError): + RemoteFileUploadApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_saved_message.py b/api/tests/unit_tests/controllers/web/test_saved_message.py new file mode 100644 index 0000000000..3d55804912 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_saved_message.py @@ -0,0 +1,97 @@ +"""Unit tests for controllers.web.saved_message endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.error import NotCompletionAppError +from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi +from services.errors.message import MessageNotExistsError + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# SavedMessageListApi (GET) +# --------------------------------------------------------------------------- +class TestSavedMessageListApiGet: + def test_non_completion_mode_raises(self, app: 
Flask) -> None: + with app.test_request_context("/saved-messages"): + with pytest.raises(NotCompletionAppError): + SavedMessageListApi().get(_chat_app(), _end_user()) + + @patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id") + def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None: + mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[]) + + with app.test_request_context("/saved-messages?limit=20"): + result = SavedMessageListApi().get(_completion_app(), _end_user()) + + assert result["limit"] == 20 + assert result["has_more"] is False + + +# --------------------------------------------------------------------------- +# SavedMessageListApi (POST) +# --------------------------------------------------------------------------- +class TestSavedMessageListApiPost: + def test_non_completion_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/saved-messages", method="POST"): + with pytest.raises(NotCompletionAppError): + SavedMessageListApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.saved_message.SavedMessageService.save") + @patch("controllers.web.saved_message.web_ns") + def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None: + msg_id = str(uuid4()) + mock_ns.payload = {"message_id": msg_id} + + with app.test_request_context("/saved-messages", method="POST"): + result = SavedMessageListApi().post(_completion_app(), _end_user()) + + assert result["result"] == "success" + + @patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError()) + @patch("controllers.web.saved_message.web_ns") + def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None: + mock_ns.payload = {"message_id": str(uuid4())} + + with app.test_request_context("/saved-messages", method="POST"): + with pytest.raises(NotFound, match="Message Not Exists"): + 
SavedMessageListApi().post(_completion_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# SavedMessageApi (DELETE) +# --------------------------------------------------------------------------- +class TestSavedMessageApi: + def test_non_completion_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"): + with pytest.raises(NotCompletionAppError): + SavedMessageApi().delete(_chat_app(), _end_user(), msg_id) + + @patch("controllers.web.saved_message.SavedMessageService.delete") + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"): + result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id) + + assert status == 204 + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/unit_tests/controllers/web/test_site.py new file mode 100644 index 0000000000..557bf93e9e --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_site.py @@ -0,0 +1,126 @@ +"""Unit tests for controllers.web.site endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.web.site import AppSiteApi, AppSiteInfo + + +def _tenant(*, status: str = "normal") -> SimpleNamespace: + return SimpleNamespace( + id="tenant-1", + status=status, + plan="basic", + custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False}, + ) + + +def _site() -> SimpleNamespace: + return SimpleNamespace( + title="Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + 
copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + +# --------------------------------------------------------------------------- +# AppSiteApi +# --------------------------------------------------------------------------- +class TestAppSiteApi: + @patch("controllers.web.site.FeatureService.get_features") + @patch("controllers.web.site.db") + def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + mock_features.return_value = SimpleNamespace(can_replace_logo=False) + site_obj = _site() + mock_db.session.query.return_value.where.return_value.first.return_value = site_obj + tenant = _tenant() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + result = AppSiteApi().get(app_model, end_user) + + # marshal_with serializes AppSiteInfo to a dict + assert result["app_id"] == "app-1" + assert result["plan"] == "basic" + assert result["enable_site"] is True + + @patch("controllers.web.site.db") + def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + mock_db.session.query.return_value.where.return_value.first.return_value = None + tenant = _tenant() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + with pytest.raises(Forbidden): + AppSiteApi().get(app_model, end_user) + + @patch("controllers.web.site.db") + def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + from models.account import TenantStatus + + mock_db.session.query.return_value.where.return_value.first.return_value = _site() 
+ tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.ARCHIVE, + plan="basic", + custom_config_dict={}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + with pytest.raises(Forbidden): + AppSiteApi().get(app_model, end_user) + + +# --------------------------------------------------------------------------- +# AppSiteInfo +# --------------------------------------------------------------------------- +class TestAppSiteInfo: + def test_basic_fields(self) -> None: + tenant = _tenant() + site_obj = _site() + info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) + + assert info.app_id == "app-1" + assert info.end_user_id == "eu-1" + assert info.enable_site is True + assert info.plan == "basic" + assert info.can_replace_logo is False + assert info.model_config is None + + @patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com")) + def test_can_replace_logo_sets_custom_config(self) -> None: + tenant = SimpleNamespace( + id="tenant-1", + plan="pro", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, + ) + site_obj = _site() + info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) + + assert info.can_replace_logo is True + assert info.custom_config["remove_webapp_brand"] is True + assert "webapp-logo" in info.custom_config["replace_webapp_logo"] diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index e62993e8d5..0661c02578 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from controllers.web.login import EmailCodeLoginApi, 
EmailCodeLoginSendEmailApi +import services.errors.account +from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi def encode_code(code: str) -> str: @@ -89,3 +90,114 @@ class TestEmailCodeLoginApi: mock_revoke_token.assert_called_once_with("token-123") mock_login.assert_called_once() mock_reset_login_rate.assert_called_once_with("user@example.com") + + +class TestLoginApi: + @patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok") + @patch("controllers.web.login.WebAppAuthService.authenticate") + def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None: + mock_auth.return_value = MagicMock() + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + response = LoginApi().post() + + assert response.get_json()["data"]["access_token"] == "access-tok" + mock_auth.assert_called_once() + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountLoginError(), + ) + def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.error import AccountBannedError + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AccountBannedError): + LoginApi().post() + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountPasswordError(), + ) + def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.auth.error import AuthenticationFailedError + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with 
pytest.raises(AuthenticationFailedError): + LoginApi().post() + + +class TestLoginStatusApi: + @patch("controllers.web.login.extract_webapp_access_token", return_value=None) + def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None: + with app.test_request_context("/web/login/status"): + result = LoginStatusApi().get() + + assert result["logged_in"] is False + assert result["app_logged_in"] is False + + @patch("controllers.web.login.decode_jwt_token") + @patch("controllers.web.login.PassportService") + @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False) + @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1") + @patch("controllers.web.login.extract_webapp_access_token", return_value="tok") + def test_public_app_user_logged_in( + self, + mock_extract: MagicMock, + mock_app_id: MagicMock, + mock_perm: MagicMock, + mock_passport: MagicMock, + mock_decode: MagicMock, + app: Flask, + ) -> None: + mock_decode.return_value = (MagicMock(), MagicMock()) + + with app.test_request_context("/web/login/status?app_code=code1"): + result = LoginStatusApi().get() + + assert result["logged_in"] is True + assert result["app_logged_in"] is True + + @patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad")) + @patch("controllers.web.login.PassportService") + @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True) + @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1") + @patch("controllers.web.login.extract_webapp_access_token", return_value="tok") + def test_private_app_passport_fails( + self, + mock_extract: MagicMock, + mock_app_id: MagicMock, + mock_perm: MagicMock, + mock_passport_cls: MagicMock, + mock_decode: MagicMock, + app: Flask, + ) -> None: + mock_passport_cls.return_value.verify.side_effect = Exception("bad") + + with 
app.test_request_context("/web/login/status?app_code=code1"): + result = LoginStatusApi().get() + + assert result["logged_in"] is False + assert result["app_logged_in"] is False + + +class TestLogoutApi: + @patch("controllers.web.login.clear_webapp_access_token_from_cookie") + def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None: + with app.test_request_context("/web/logout", method="POST"): + response = LogoutApi().post() + + assert response.get_json() == {"result": "success"} + mock_clear.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_passport.py b/api/tests/unit_tests/controllers/web/test_web_passport.py new file mode 100644 index 0000000000..19b1d8504a --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_passport.py @@ -0,0 +1,192 @@ +"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web.error import WebAppAuthRequiredError +from controllers.web.passport import ( + PassportResource, + decode_enterprise_webapp_user_id, + exchange_token_for_existing_web_user, + generate_session_id, +) +from services.webapp_auth_service import WebAppAuthType + + +# --------------------------------------------------------------------------- +# decode_enterprise_webapp_user_id +# --------------------------------------------------------------------------- +class TestDecodeEnterpriseWebappUserId: + def test_none_token_returns_none(self) -> None: + assert decode_enterprise_webapp_user_id(None) is None + + @patch("controllers.web.passport.PassportService") + def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = { + "token_source": "webapp_login_token", + "user_id": 
"u1", + } + result = decode_enterprise_webapp_user_id("valid-jwt") + assert result["user_id"] == "u1" + + @patch("controllers.web.passport.PassportService") + def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = { + "token_source": "other_source", + } + with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"): + decode_enterprise_webapp_user_id("bad-jwt") + + @patch("controllers.web.passport.PassportService") + def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = {} + with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"): + decode_enterprise_webapp_user_id("no-source-jwt") + + +# --------------------------------------------------------------------------- +# generate_session_id +# --------------------------------------------------------------------------- +class TestGenerateSessionId: + @patch("controllers.web.passport.db") + def test_returns_unique_session_id(self, mock_db: MagicMock) -> None: + mock_db.session.scalar.return_value = 0 + sid = generate_session_id() + assert isinstance(sid, str) + assert len(sid) == 36 # UUID format + + @patch("controllers.web.passport.db") + def test_retries_on_collision(self, mock_db: MagicMock) -> None: + # First call returns count=1 (collision), second returns 0 + mock_db.session.scalar.side_effect = [1, 0] + sid = generate_session_id() + assert isinstance(sid, str) + assert mock_db.session.scalar.call_count == 2 + + +# --------------------------------------------------------------------------- +# exchange_token_for_existing_web_user +# --------------------------------------------------------------------------- +class TestExchangeTokenForExistingWebUser: + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, 
mock_passport_cls: MagicMock) -> None: + site = SimpleNamespace(code="code1", app_id="app-1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + + decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external" + with pytest.raises(WebAppAuthRequiredError, match="external"): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL + ) + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + site = SimpleNamespace(code="code1", app_id="app-1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + + decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal" + with pytest.raises(WebAppAuthRequiredError, match="internal"): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL + ) + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + mock_db.session.scalar.return_value = None + decoded = {"user_id": "u1", "auth_type": "external"} + with pytest.raises(NotFound): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL + ) + + +# --------------------------------------------------------------------------- +# PassportResource.get +# --------------------------------------------------------------------------- +class TestPassportResource: + @patch("controllers.web.passport.FeatureService.get_system_features") + def 
test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + with app.test_request_context("/passport"): + with pytest.raises(Unauthorized, match="X-App-Code"): + PassportResource().get() + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.generate_session_id", return_value="new-sess-id") + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_creates_new_end_user_when_no_user_id( + self, + mock_features: MagicMock, + mock_db: MagicMock, + mock_gen_session: MagicMock, + mock_passport_cls: MagicMock, + app: Flask, + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + mock_passport_cls.return_value.issue.return_value = "issued-token" + + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + response = PassportResource().get() + + assert response.get_json()["access_token"] == "issued-token" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_reuses_existing_end_user_when_user_id_provided( + self, + mock_features: MagicMock, + mock_db: MagicMock, + mock_passport_cls: MagicMock, + app: Flask, + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + existing_user 
= SimpleNamespace(id="eu-1", session_id="sess-existing") + mock_db.session.scalar.side_effect = [site, app_model, existing_user] + mock_passport_cls.return_value.issue.return_value = "reused-token" + + with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}): + response = PassportResource().get() + + assert response.get_json()["access_token"] == "reused-token" + # Should not create a new end user + mock_db.session.add.assert_not_called() + + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_db.session.scalar.return_value = None + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + PassportResource().get() + + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False) + mock_db.session.scalar.side_effect = [site, disabled_app] + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + PassportResource().get() diff --git a/api/tests/unit_tests/controllers/web/test_workflow.py b/api/tests/unit_tests/controllers/web/test_workflow.py new file mode 100644 index 0000000000..0973340527 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_workflow.py @@ -0,0 +1,95 @@ +"""Unit tests for controllers.web.workflow endpoints.""" + +from __future__ import annotations + +from 
types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.error import ( + NotWorkflowAppError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi +from core.errors.error import ProviderTokenNotInitError, QuotaExceededError + + +def _workflow_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="workflow") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# WorkflowRunApi +# --------------------------------------------------------------------------- +class TestWorkflowRunApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(NotWorkflowAppError): + WorkflowRunApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"}) + @patch("controllers.web.workflow.AppGenerateService.generate") + @patch("controllers.web.workflow.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {"key": "val"}} + mock_gen.return_value = "response" + + with app.test_request_context("/workflows/run", method="POST"): + result = WorkflowRunApi().post(_workflow_app(), _end_user()) + + assert result == {"result": "ok"} + + @patch( + "controllers.web.workflow.AppGenerateService.generate", + side_effect=ProviderTokenNotInitError(description="not init"), + ) + @patch("controllers.web.workflow.web_ns") + def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with 
app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(ProviderNotInitializeError): + WorkflowRunApi().post(_workflow_app(), _end_user()) + + @patch( + "controllers.web.workflow.AppGenerateService.generate", + side_effect=QuotaExceededError(), + ) + @patch("controllers.web.workflow.web_ns") + def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(ProviderQuotaExceededError): + WorkflowRunApi().post(_workflow_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# WorkflowTaskStopApi +# --------------------------------------------------------------------------- +class TestWorkflowTaskStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): + with pytest.raises(NotWorkflowAppError): + WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1") + + @patch("controllers.web.workflow.GraphEngineManager.send_stop_command") + @patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check") + def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None: + with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): + result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1") + + assert result == {"result": "success"} + mock_legacy.assert_called_once_with("task-1") + mock_graph.assert_called_once_with("task-1") diff --git a/api/tests/unit_tests/controllers/web/test_workflow_events.py b/api/tests/unit_tests/controllers/web/test_workflow_events.py new file mode 100644 index 0000000000..64c09b5e22 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_workflow_events.py @@ -0,0 +1,127 @@ +"""Unit tests for controllers.web.workflow_events endpoints.""" + 
+from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.error import NotFoundError +from controllers.web.workflow_events import WorkflowEventsApi +from models.enums import CreatorUserRole + + +def _workflow_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# WorkflowEventsApi +# --------------------------------------------------------------------------- +class TestWorkflowEventsApi: + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="other-app", + created_by_role=CreatorUserRole.END_USER, + created_by="eu-1", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + 
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_not_created_by_end_user( + self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask + ) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="eu-1", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="other-user", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.WorkflowResponseConverter") + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_finished_run_returns_sse_response( + self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask + ) -> None: + from datetime import datetime + + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + 
app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="eu-1", + finished_at=datetime(2024, 1, 1), + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + finish_response = MagicMock() + finish_response.model_dump.return_value = {"task_id": "run-1"} + finish_response.event.value = "workflow_finished" + mock_converter.workflow_run_result_to_finish_response.return_value = finish_response + + with app.test_request_context("/workflow/run-1/events"): + response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + assert response.mimetype == "text/event-stream" diff --git a/api/tests/unit_tests/controllers/web/test_wraps.py b/api/tests/unit_tests/controllers/web/test_wraps.py new file mode 100644 index 0000000000..85049ae975 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_wraps.py @@ -0,0 +1,393 @@ +"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized + +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError +from controllers.web.wraps import ( + _validate_user_accessibility, + _validate_webapp_token, + decode_jwt_token, +) + + +# --------------------------------------------------------------------------- +# _validate_webapp_token +# --------------------------------------------------------------------------- +class TestValidateWebappToken: + def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None: + """When both flags are true, a non-webapp source must raise.""" + decoded = {"token_source": "other"} + with pytest.raises(WebAppAuthRequiredError): + 
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None: + decoded = {"token_source": "webapp"} + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None: + decoded = {} + with pytest.raises(WebAppAuthRequiredError): + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_public_app_rejects_webapp_source(self) -> None: + """When auth is not required, a webapp-sourced token must be rejected.""" + decoded = {"token_source": "webapp"} + with pytest.raises(Unauthorized): + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_public_app_accepts_non_webapp_source(self) -> None: + decoded = {"token_source": "other"} + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_public_app_accepts_no_source(self) -> None: + decoded = {} + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_system_enabled_but_app_public(self) -> None: + """system_webapp_auth_enabled=True but app is public — webapp source rejected.""" + decoded = {"token_source": "webapp"} + with pytest.raises(Unauthorized): + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True) + + +# --------------------------------------------------------------------------- +# _validate_user_accessibility +# --------------------------------------------------------------------------- +class TestValidateUserAccessibility: + def test_skips_when_auth_disabled(self) -> None: + """No checks when system or app auth is disabled.""" + _validate_user_accessibility( + decoded={}, + app_code="code", + app_web_auth_enabled=False, + 
system_webapp_auth_enabled=False, + webapp_settings=None, + ) + + def test_missing_user_id_raises(self) -> None: + decoded = {} + with pytest.raises(WebAppAuthRequiredError): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=SimpleNamespace(access_mode="internal"), + ) + + def test_missing_webapp_settings_raises(self) -> None: + decoded = {"user_id": "u1"} + with pytest.raises(WebAppAuthRequiredError, match="settings not found"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=None, + ) + + def test_missing_auth_type_raises(self) -> None: + decoded = {"user_id": "u1", "granted_at": 1} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + def test_missing_granted_at_raises(self) -> None: + decoded = {"user_id": "u1", "auth_type": "external"} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_external_auth_type_checks_sso_update_time( + self, mock_perm_check: MagicMock, mock_sso_time: MagicMock + ) -> None: + # granted_at is before SSO update time → denied + mock_sso_time.return_value = datetime.now(UTC) + old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp()) + decoded = 
{"user_id": "u1", "auth_type": "external", "granted_at": old_granted} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_internal_auth_type_checks_workspace_sso_update_time( + self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock + ) -> None: + mock_workspace_sso.return_value = datetime.now(UTC) + old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp()) + decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_external_auth_passes_when_granted_after_sso_update( + self, mock_perm_check: MagicMock, mock_sso_time: MagicMock + ) -> None: + mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2) + recent_granted = int(datetime.now(UTC).timestamp()) + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted} + settings = SimpleNamespace(access_mode="public") + # Should not raise + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + 
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False) + @patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True) + def test_permission_check_denies_unauthorized_user( + self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock + ) -> None: + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())} + settings = SimpleNamespace(access_mode="internal") + with pytest.raises(WebAppAuthAccessDeniedError): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + +# --------------------------------------------------------------------------- +# decode_jwt_token +# --------------------------------------------------------------------------- +class TestDecodeJwtToken: + @patch("controllers.web.wraps._validate_user_accessibility") + @patch("controllers.web.wraps._validate_webapp_token") + @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.wraps.AppService.get_app_id_by_code") + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_happy_path( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + mock_app_id: MagicMock, + mock_access_mode: MagicMock, + mock_validate_token: MagicMock, + mock_validate_user: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + 
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + end_user = SimpleNamespace(id="eu-1", session_id="sess-1") + + # Configure session mock to return correct objects via scalar() + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, end_user] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + result_app, result_user = decode_jwt_token() + + assert result_app.id == "app-1" + assert result_user.id == "eu-1" + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.extract_webapp_passport") + def test_missing_token_raises_unauthorized( + self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_extract.return_value = None + + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(Unauthorized): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_missing_app_raises_not_found( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = 
SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + session_mock = MagicMock() + session_mock.scalar.return_value = None # No app found + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_disabled_site_raises_bad_request( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=False) + + session_mock = MagicMock() + # scalar calls: app_model, site (code found), then end_user + session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(BadRequest, match="Site is disabled"): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + 
@patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_missing_end_user_raises_not_found( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, None] # end_user is None + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_user_id_mismatch_raises_unauthorized( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + end_user = SimpleNamespace(id="eu-1", session_id="sess-1") + + session_mock = MagicMock() + 
session_mock.scalar.side_effect = [app_model, site, end_user] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(Unauthorized, match="expired"): + decode_jwt_token(user_id="different-user") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 12ab587564..15aceef2c7 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -125,7 +125,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, @@ -265,7 +269,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + 
patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, @@ -412,7 +420,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py new file mode 100644 index 0000000000..5792a2f1e2 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -0,0 +1,170 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import QueueStopEvent +from core.moderation.base import ModerationError + + +@pytest.fixture +def build_runner(): + """Construct a minimal AdvancedChatAppRunner with heavy dependencies 
mocked.""" + app_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Mocks for constructor args + mock_queue_manager = MagicMock() + + mock_conversation = MagicMock() + mock_conversation.id = str(uuid4()) + mock_conversation.app_id = app_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + mock_workflow = MagicMock() + mock_workflow.id = workflow_id + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + gen = MagicMock(spec=AdvancedChatAppGenerateEntity) + gen.app_config = mock_app_config + gen.inputs = {"q": "raw"} + gen.query = "raw-query" + gen.files = [] + gen.user_id = str(uuid4()) + gen.invoke_from = InvokeFrom.SERVICE_API + gen.workflow_run_id = str(uuid4()) + gen.task_id = str(uuid4()) + gen.call_depth = 0 + gen.single_iteration_run = None + gen.single_loop_run = None + gen.trace_manager = None + + runner = AdvancedChatAppRunner( + application_generate_entity=gen, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + return runner + + +def _patch_common_run_deps(runner: AdvancedChatAppRunner): + """Context manager that patches common heavy deps used by run().""" + return patch.multiple( + "core.app.apps.advanced_chat.app_runner", + Session=MagicMock( + return_value=MagicMock( + __enter__=lambda s: s, + __exit__=lambda *a, **k: False, + scalar=lambda *a, **k: MagicMock(), + ), + ), + select=MagicMock(), + db=MagicMock(engine=MagicMock()), + RedisChannel=MagicMock(), + redis_client=MagicMock(), + 
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}), + GraphRuntimeState=MagicMock(), + ) + + +def test_handle_input_moderation_stops_on_moderation_error(build_runner): + runner = build_runner + + # moderation_for_inputs raises ModerationError -> should stop and emit stop event + with ( + patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("blocked")), + patch.object(runner, "_complete_with_stream_output") as mock_complete, + ): + stop, new_inputs, new_query = runner.handle_input_moderation( + app_record=MagicMock(), + app_generate_entity=runner.application_generate_entity, + inputs={"k": "v"}, + query="hello", + message_id="mid", + ) + + assert stop is True + # inputs/query should be unchanged on error path + assert new_inputs == {"k": "v"} + assert new_query == "hello" + # ensure stopped_by reason is INPUT_MODERATION + assert mock_complete.called + args, kwargs = mock_complete.call_args + assert kwargs.get("stopped_by") == QueueStopEvent.StopBy.INPUT_MODERATION + + +def test_run_applies_overridden_inputs_and_query_from_moderation(build_runner): + runner = build_runner + + overridden_inputs = {"q": "sanitized"} + overridden_query = "sanitized-query" + + with ( + _patch_common_run_deps(runner), + patch.object( + runner, + "moderation_for_inputs", + return_value=(True, overridden_inputs, overridden_query), + ) as mock_moderate, + patch.object(runner, "handle_annotation_reply", return_value=False) as mock_anno, + patch.object(runner, "_init_graph", return_value=MagicMock()) as mock_init_graph, + ): + runner.run() + + # moderation called with original values + mock_moderate.assert_called_once() + + # application_generate_entity should be updated to overridden values + assert runner.application_generate_entity.inputs == overridden_inputs + assert runner.application_generate_entity.query == overridden_query + + # annotation reply should use the new query + mock_anno.assert_called() + assert mock_anno.call_args.kwargs.get("query") 
== overridden_query + + # since not stopped, graph initialization should proceed + assert mock_init_graph.called + + +def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_runner): + runner = build_runner + + with ( + _patch_common_run_deps(runner), + # Simulate handle_input_moderation signalling to stop + patch.object( + runner, + "handle_input_moderation", + return_value=(True, runner.application_generate_entity.inputs, runner.application_generate_entity.query), + ) as mock_handle, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_annotation_reply") as mock_anno, + ): + runner.run() + + mock_handle.assert_called_once() + # Ensure no further steps executed + mock_anno.assert_not_called() + mock_init_graph.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py index be773557f6..83a6e0f231 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py @@ -9,8 +9,16 @@ import pytest from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent +from core.app.entities.queue_entities import ( + QueuePingEvent, + QueueTextChunkEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import StreamEvent from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import 
EndUser @@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None: assert message.answer == "beforeafter" assert message.status == MessageStatus.NORMAL + + +def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_id = "workflow-1" + pipeline._ensure_workflow_initialized = mock.Mock() + runtime_state = SimpleNamespace() + pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state) + pipeline._handle_advanced_chat_message_end_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)]) + ) + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace( + event=StreamEvent.WORKFLOW_FINISHED, + data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED), + ) + + event = QueueWorkflowSucceededEvent(outputs={}) + responses = list(pipeline._handle_workflow_succeeded_event(event)) + + assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED] + + +def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_id = "workflow-1" + pipeline._ensure_workflow_initialized = mock.Mock() + runtime_state = SimpleNamespace() + pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state) + pipeline._handle_advanced_chat_message_end_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)]) + ) + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace( + event=StreamEvent.WORKFLOW_FINISHED, + 
data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED), + ) + + event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + responses = list(pipeline._handle_workflow_partial_success_event(event)) + + assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED] + + +def test_process_stream_response_breaks_after_workflow_succeeded() -> None: + pipeline = _build_pipeline() + succeeded_event = QueueWorkflowSucceededEvent(outputs={}) + ping_event = QueuePingEvent() + queue_messages = [ + SimpleNamespace(event=succeeded_event), + SimpleNamespace(event=ping_event), + ] + + pipeline._conversation_name_generate_thread = None + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages) + pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING)) + pipeline._handle_workflow_succeeded_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)]) + ) + + responses = list(pipeline._process_stream_response()) + + assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED] + pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None) + pipeline._base_task_pipeline.ping_stream_response.assert_not_called() + + +def test_process_stream_response_breaks_after_workflow_partial_success() -> None: + pipeline = _build_pipeline() + partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + ping_event = QueuePingEvent() + queue_messages = [ + SimpleNamespace(event=partial_event), + SimpleNamespace(event=ping_event), + ] + + pipeline._conversation_name_generate_thread = None + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + 
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages) + pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING)) + pipeline._handle_workflow_partial_success_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)]) + ) + + responses = list(pipeline._process_stream_response()) + + assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED] + pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None) + pipeline._base_task_pipeline.ping_stream_response.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py index f0d9afc0db..a25e3ec3f5 100644 --- a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py @@ -124,12 +124,12 @@ def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch): def start(self): self.started = True - def fake_thread(**kwargs): + def fake_thread(*args, **kwargs): thread = DummyThread(**kwargs) captured["thread"] = thread return thread - monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread) + monkeypatch.setattr(message_cycle_manager, "Timer", fake_thread) manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock()) thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello") diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 44af89601c..e019a4b977 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -1,13 +1,8 @@ import sys import time -from pathlib import Path from types import ModuleType, 
SimpleNamespace from typing import Any -API_DIR = str(Path(__file__).resolve().parents[5]) -if API_DIR not in sys.path: - sys.path.insert(0, API_DIR) - import dify_graph.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py new file mode 100644 index 0000000000..582990c88a --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -0,0 +1,425 @@ +""" +Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method. + +This test suite ensures that the files array is correctly populated in the message_end +SSE event, which is critical for vision/image chat responses to render correctly. + +Test Coverage: +- Files array populated when MessageFile records exist +- Files array is None when no MessageFile records exist +- Correct signed URL generation for LOCAL_FILE transfer method +- Correct URL handling for REMOTE_URL transfer method +- Correct URL handling for TOOL_FILE transfer method +- Proper file metadata formatting (filename, mime_type, size, extension) +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest +from sqlalchemy.orm import Session + +from core.app.entities.task_entities import MessageEndStreamResponse +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from dify_graph.file.enums import FileTransferMethod +from models.model import MessageFile, UploadFile + + +class TestMessageEndStreamResponseFiles: + """Test suite for files array population in message_end SSE event.""" + + @pytest.fixture + def mock_pipeline(self): + """Create a mock EasyUIBasedGenerateTaskPipeline instance.""" + pipeline = 
Mock(spec=EasyUIBasedGenerateTaskPipeline) + pipeline._message_id = str(uuid.uuid4()) + pipeline._task_state = Mock() + pipeline._task_state.metadata = Mock() + pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"}) + pipeline._task_state.llm_result = Mock() + pipeline._task_state.llm_result.usage = Mock() + pipeline._application_generate_entity = Mock() + pipeline._application_generate_entity.task_id = str(uuid.uuid4()) + return pipeline + + @pytest.fixture + def mock_message_file_local(self): + """Create a mock MessageFile with LOCAL_FILE transfer method.""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + message_file.transfer_method = FileTransferMethod.LOCAL_FILE + message_file.upload_file_id = str(uuid.uuid4()) + message_file.url = None + message_file.type = "image" + return message_file + + @pytest.fixture + def mock_message_file_remote(self): + """Create a mock MessageFile with REMOTE_URL transfer method.""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + message_file.transfer_method = FileTransferMethod.REMOTE_URL + message_file.upload_file_id = None + message_file.url = "https://example.com/image.jpg" + message_file.type = "image" + return message_file + + @pytest.fixture + def mock_message_file_tool(self): + """Create a mock MessageFile with TOOL_FILE transfer method.""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + message_file.transfer_method = FileTransferMethod.TOOL_FILE + message_file.upload_file_id = None + message_file.url = "tool_file_123.png" + message_file.type = "image" + return message_file + + @pytest.fixture + def mock_upload_file(self, mock_message_file_local): + """Create a mock UploadFile.""" + upload_file = Mock(spec=UploadFile) + upload_file.id = 
mock_message_file_local.upload_file_id + upload_file.name = "test_image.png" + upload_file.mime_type = "image/png" + upload_file.size = 1024 + upload_file.extension = "png" + return upload_file + + def test_message_end_with_no_files(self, mock_pipeline): + """Test that files array is None when no MessageFile records exist.""" + # Arrange + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is None + assert result.id == mock_pipeline._message_id + assert result.metadata == {"test": "metadata"} + + def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file): + """Test that files array is populated correctly for LOCAL_FILE transfer method.""" + # Arrange + mock_message_file_local.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + # First query: MessageFile + mock_message_files_result = Mock() + mock_message_files_result.all.return_value = [mock_message_file_local] + + # Second 
query: UploadFile (batch query to avoid N+1) + mock_upload_files_result = Mock() + mock_upload_files_result.all.return_value = [mock_upload_file] + + # Setup scalars to return different results for different queries + call_count = [0] # Use list to allow modification in nested function + + def scalars_side_effect(query): + call_count[0] += 1 + # First call is for MessageFile, second call is for UploadFile + if call_count[0] == 1: + return mock_message_files_result + else: + return mock_upload_files_result + + mock_session.scalars.side_effect = scalars_side_effect + mock_get_url.return_value = "https://example.com/signed-url?signature=abc123" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert file_dict["related_id"] == mock_message_file_local.id + assert file_dict["filename"] == "test_image.png" + assert file_dict["mime_type"] == "image/png" + assert file_dict["size"] == 1024 + assert file_dict["extension"] == ".png" + assert file_dict["type"] == "image" + assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value + assert "https://example.com/signed-url" in file_dict["url"] + assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id + assert file_dict["remote_url"] == "" + + # Verify database queries + # Should be called twice: once for MessageFile, once for UploadFile + assert mock_session.scalars.call_count == 2 + mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id)) + + def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote): + """Test that files array is populated correctly for REMOTE_URL transfer method.""" + # Arrange + mock_message_file_remote.message_id = mock_pipeline._message_id + + with ( + 
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_remote] + mock_session.scalars.return_value = mock_scalars_result + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert file_dict["related_id"] == mock_message_file_remote.id + assert file_dict["filename"] == "image.jpg" + assert file_dict["url"] == "https://example.com/image.jpg" + assert file_dict["extension"] == ".jpg" + assert file_dict["type"] == "image" + assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value + assert file_dict["remote_url"] == "https://example.com/image.jpg" + assert file_dict["upload_file_id"] == mock_message_file_remote.id + + # Verify only one query for message_files is made + mock_session.scalars.assert_called_once() + + def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool): + """Test that files array is populated correctly for TOOL_FILE with HTTP URL.""" + # Arrange + mock_message_file_tool.message_id = mock_pipeline._message_id + mock_message_file_tool.url = "https://example.com/tool_file.png" + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + mock_session = MagicMock(spec=Session) + 
mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_tool] + mock_session.scalars.return_value = mock_scalars_result + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert file_dict["url"] == "https://example.com/tool_file.png" + assert file_dict["filename"] == "tool_file.png" + assert file_dict["extension"] == ".png" + assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value + + def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool): + """Test that files array is populated correctly for TOOL_FILE with local path.""" + # Arrange + mock_message_file_tool.message_id = mock_pipeline._message_id + mock_message_file_tool.url = "tool_file_123.png" + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_tool] + mock_session.scalars.return_value = mock_scalars_result + + mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + 
assert len(result.files) == 1 + + file_dict = result.files[0] + assert "https://example.com/signed-tool-file.png" in file_dict["url"] + assert file_dict["filename"] == "tool_file_123.png" + assert file_dict["extension"] == ".png" + assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value + + # Verify tool file signing was called + mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png") + + def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool): + """Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin.""" + mock_message_file_tool.message_id = mock_pipeline._message_id + mock_message_file_tool.url = "tool_file_abc.verylongextension" + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_tool] + mock_session.scalars.return_value = mock_scalars_result + mock_sign_tool.return_value = "https://example.com/signed.bin" + + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + assert result.files is not None + file_dict = result.files[0] + assert file_dict["extension"] == ".bin" + mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin") + + def test_message_end_with_multiple_files( + self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file + ): + """Test that files array contains all MessageFile records when multiple exist.""" + # Arrange + 
mock_message_file_local.message_id = mock_pipeline._message_id + mock_message_file_remote.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + # First query: MessageFile + mock_message_files_result = Mock() + mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote] + + # Second query: UploadFile (batch query to avoid N+1) + mock_upload_files_result = Mock() + mock_upload_files_result.all.return_value = [mock_upload_file] + + # Setup scalars to return different results for different queries + call_count = [0] # Use list to allow modification in nested function + + def scalars_side_effect(query): + call_count[0] += 1 + # First call is for MessageFile, second call is for UploadFile + if call_count[0] == 1: + return mock_message_files_result + else: + return mock_upload_files_result + + mock_session.scalars.side_effect = scalars_side_effect + mock_get_url.return_value = "https://example.com/signed-url?signature=abc123" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 2 + + # Verify both files are present + file_ids = [f["related_id"] for f in result.files] + assert mock_message_file_local.id in file_ids + assert mock_message_file_remote.id in file_ids + + def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local): + """Test 
fallback when UploadFile is not found for LOCAL_FILE.""" + # Arrange + mock_message_file_local.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + # First query: MessageFile + mock_message_files_result = Mock() + mock_message_files_result.all.return_value = [mock_message_file_local] + + # Second query: UploadFile (batch query) - returns empty list (not found) + mock_upload_files_result = Mock() + mock_upload_files_result.all.return_value = [] # UploadFile not found + + # Setup scalars to return different results for different queries + call_count = [0] # Use list to allow modification in nested function + + def scalars_side_effect(query): + call_count[0] += 1 + # First call is for MessageFile, second call is for UploadFile + if call_count[0] == 1: + return mock_message_files_result + else: + return mock_upload_files_result + + mock_session.scalars.side_effect = scalars_side_effect + mock_get_url.return_value = "https://example.com/fallback-url?signature=def456" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert "https://example.com/fallback-url" in file_dict["url"] + # Verify fallback URL was generated using upload_file_id from message_file + mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id)) diff --git 
a/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py new file mode 100644 index 0000000000..5482b4db52 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +from configs import dify_config +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType + + +class ConcreteDatasourcePlugin(DatasourcePlugin): + """ + Concrete implementation of DatasourcePlugin for testing purposes. + Since DatasourcePlugin is an ABC, we need a concrete class to instantiate it. + """ + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.LOCAL_FILE + + +class TestDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + + # Act + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Act + provider_type = plugin.datasource_provider_type() + # Call the base class method to ensure it's covered + base_provider_type = DatasourcePlugin.datasource_provider_type(plugin) + + # Assert + assert provider_type == DatasourceProviderType.LOCAL_FILE + assert base_provider_type == DatasourceProviderType.LOCAL_FILE + + def test_fork_datasource_runtime(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + 
mock_entity_copy = MagicMock(spec=DatasourceEntity) + mock_entity.model_copy.return_value = mock_entity_copy + + runtime = MagicMock(spec=DatasourceRuntime) + new_runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + + plugin = ConcreteDatasourcePlugin(entity=mock_entity, runtime=runtime, icon=icon) + + # Act + new_plugin = plugin.fork_datasource_runtime(new_runtime) + + # Assert + assert isinstance(new_plugin, ConcreteDatasourcePlugin) + assert new_plugin.entity == mock_entity_copy + assert new_plugin.runtime == new_runtime + assert new_plugin.icon == icon + mock_entity.model_copy.assert_called_once() + + def test_get_icon_url(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + tenant_id = "test-tenant-id" + + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Mocking dify_config.CONSOLE_API_URL + with patch.object(dify_config, "CONSOLE_API_URL", "https://api.dify.ai"): + # Act + icon_url = plugin.get_icon_url(tenant_id) + + # Assert + expected_url = ( + f"https://api.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={icon}" + ) + assert icon_url == expected_url diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py new file mode 100644 index 0000000000..6a3d21a33d --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py @@ -0,0 +1,265 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.entities.provider_entities import ProviderConfig +from core.tools.errors 
import ToolProviderCredentialValidationError + + +class ConcreteDatasourcePluginProviderController(DatasourcePluginProviderController): + """ + Concrete implementation of DatasourcePluginProviderController for testing purposes. + """ + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: + return MagicMock(spec=DatasourcePlugin) + + +class TestDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + tenant_id = "test-tenant-id" + + # Act + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Assert + assert controller.entity == mock_entity + assert controller.tenant_id == tenant_id + + def test_need_credentials(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + tenant_id = "test-tenant-id" + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Case 1: credentials_schema is None + mock_entity.credentials_schema = None + assert controller.need_credentials is False + + # Case 2: credentials_schema is empty + mock_entity.credentials_schema = [] + assert controller.need_credentials is False + + # Case 3: credentials_schema has items + mock_entity.credentials_schema = [MagicMock()] + assert controller.need_credentials is True + + @patch("core.datasource.__base.datasource_provider.PluginToolManager") + def test_validate_credentials(self, mock_manager_class): + # Arrange + mock_manager = mock_manager_class.return_value + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.identity = MagicMock() + mock_entity.identity.name = "test-provider" + tenant_id = "test-tenant-id" + user_id = "test-user-id" + credentials = {"api_key": "secret"} + + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Act: Successful validation + 
mock_manager.validate_datasource_credentials.return_value = True + controller._validate_credentials(user_id, credentials) + + mock_manager.validate_datasource_credentials.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + provider="test-provider", + credentials=credentials, + ) + + # Act: Failed validation + mock_manager.validate_datasource_credentials.return_value = False + with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"): + controller._validate_credentials(user_id, credentials) + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.LOCAL_FILE + + def test_validate_credentials_format_empty_schema(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + credentials = {} + + # Act & Assert (Should not raise anything) + controller.validate_credentials_format(credentials) + + def test_validate_credentials_format_unknown_credential(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.identity = MagicMock() + mock_entity.identity.name = "test-provider" + mock_entity.credentials_schema = [] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + credentials = {"unknown": "value"} + + # Act & Assert + with pytest.raises( + ToolProviderCredentialValidationError, match="credential unknown not found in provider test-provider" + ): + controller.validate_credentials_format(credentials) + + def test_validate_credentials_format_required_missing(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "api_key" + 
mock_config.required = True + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + with pytest.raises(ToolProviderCredentialValidationError, match="credential api_key is required"): + controller.validate_credentials_format({}) + + def test_validate_credentials_format_not_required_null(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "optional" + mock_config.required = False + mock_config.default = None + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + credentials = {"optional": None} + controller.validate_credentials_format(credentials) + assert credentials["optional"] is None + + def test_validate_credentials_format_type_mismatch_text(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "text_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.TEXT_INPUT + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + with pytest.raises(ToolProviderCredentialValidationError, match="credential text_field should be string"): + controller.validate_credentials_format({"text_field": 123}) + + def test_validate_credentials_format_select_validation(self): + # Arrange + mock_option = MagicMock() + mock_option.value = "opt1" + + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "select_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.SELECT + mock_config.options = [mock_option] + + mock_entity = 
MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Case 1: Value not string + with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be string"): + controller.validate_credentials_format({"select_field": 123}) + + # Case 2: Options not list + mock_config.options = "invalid" + with pytest.raises( + ToolProviderCredentialValidationError, match="credential select_field options should be list" + ): + controller.validate_credentials_format({"select_field": "opt1"}) + + # Case 3: Value not in options + mock_config.options = [mock_option] + with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be one of"): + controller.validate_credentials_format({"select_field": "invalid_opt"}) + + def test_get_datasource_base(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + result = DatasourcePluginProviderController.get_datasource(controller, "test") + + # Assert + assert result is None + + def test_validate_credentials_format_hits_pop(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "valid_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.TEXT_INPUT + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {"valid_field": "valid_value"} + controller.validate_credentials_format(credentials) + + # Assert + assert "valid_field" in credentials + assert credentials["valid_field"] == "valid_value" + + def test_validate_credentials_format_hits_continue(self): + # Arrange + 
mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "optional_field" + mock_config.required = False + mock_config.default = None + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {"optional_field": None} + controller.validate_credentials_format(credentials) + + # Assert + assert credentials["optional_field"] is None + + def test_validate_credentials_format_default_values(self): + # Arrange + mock_config_text = MagicMock(spec=ProviderConfig) + mock_config_text.name = "text_def" + mock_config_text.required = False + mock_config_text.type = ProviderConfig.Type.TEXT_INPUT + mock_config_text.default = 123 # Int default, should be converted to str + + mock_config_other = MagicMock(spec=ProviderConfig) + mock_config_other.name = "other_def" + mock_config_other.required = False + mock_config_other.type = "OTHER" + mock_config_other.default = "fallback" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config_text, mock_config_other] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {} + controller.validate_credentials_format(credentials) + + # Assert + assert credentials["text_def"] == "123" + assert credentials["other_def"] == "fallback" diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py new file mode 100644 index 0000000000..2bca9155e9 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py @@ -0,0 +1,26 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime +from 
core.datasource.entities.datasource_entities import DatasourceInvokeFrom + + +class TestDatasourceRuntime: + def test_init(self): + runtime = DatasourceRuntime( + tenant_id="test-tenant", + datasource_id="test-ds", + invoke_from=InvokeFrom.DEBUGGER, + datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, + credentials={"key": "val"}, + runtime_parameters={"p": "v"}, + ) + assert runtime.tenant_id == "test-tenant" + assert runtime.datasource_id == "test-ds" + assert runtime.credentials["key"] == "val" + + def test_fake_datasource_runtime(self): + # This covers the FakeDatasourceRuntime class and its __init__ + runtime = FakeDatasourceRuntime() + assert runtime.tenant_id == "fake_tenant_id" + assert runtime.datasource_id == "fake_datasource_id" + assert runtime.invoke_from == InvokeFrom.DEBUGGER + assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE diff --git a/api/tests/unit_tests/core/datasource/entities/test_api_entities.py b/api/tests/unit_tests/core/datasource/entities/test_api_entities.py new file mode 100644 index 0000000000..9855b4040a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_api_entities.py @@ -0,0 +1,150 @@ +from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import DatasourceParameter +from core.tools.entities.common_entities import I18nObject + + +def test_datasource_api_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + + entity = DatasourceApiEntity( + author="author", name="name", label=label, description=description, labels=["l1", "l2"] + ) + + assert entity.author == "author" + assert entity.name == "name" + assert entity.label == label + assert entity.description == description + assert entity.labels == ["l1", "l2"] + assert entity.parameters is None + assert entity.output_schema is None + + +def 
test_datasource_provider_api_entity_defaults(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + entity = DatasourceProviderApiEntity( + id="id", author="author", name="name", description=description, icon="icon", label=label, type="type" + ) + + assert entity.id == "id" + assert entity.datasources == [] + assert entity.is_team_authorization is False + assert entity.allow_delete is True + assert entity.plugin_id == "" + assert entity.plugin_unique_identifier == "" + assert entity.labels == [] + + +def test_datasource_provider_api_entity_convert_none_to_empty_list(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + # Implicitly testing the field_validator "convert_none_to_empty_list" + entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=None, # type: ignore + ) + + assert entity.datasources == [] + + +def test_datasource_provider_api_entity_to_dict(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + # Create a parameter that should be converted + param = DatasourceParameter.get_simple_instance( + name="test_param", typ=DatasourceParameter.DatasourceParameterType.SYSTEM_FILES, required=True + ) + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=[param] + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + masked_credentials={"key": "masked"}, + datasources=[ds_entity], + labels=["l1"], + ) + + result = provider_entity.to_dict() + + assert result["id"] == "id" + assert result["author"] == "author" + assert result["name"] == "name" + assert result["description"] == description.to_dict() + 
assert result["icon"] == "icon" + assert result["label"] == label.to_dict() + assert result["type"] == "type" + assert result["team_credentials"] == {"key": "masked"} + assert result["is_team_authorization"] is False + assert result["allow_delete"] is True + assert result["labels"] == ["l1"] + + # Check if parameter type was converted from SYSTEM_FILES to files + assert result["datasources"][0]["parameters"][0]["type"] == "files" + + +def test_datasource_provider_api_entity_to_dict_no_params(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=None + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=[ds_entity], + ) + + result = provider_entity.to_dict() + assert result["datasources"][0]["parameters"] is None + + +def test_datasource_provider_api_entity_to_dict_other_param_type(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + param = DatasourceParameter.get_simple_instance( + name="test_param", typ=DatasourceParameter.DatasourceParameterType.STRING, required=True + ) + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=[param] + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=[ds_entity], + ) + + result = provider_entity.to_dict() + assert result["datasources"][0]["parameters"][0]["type"] == "string" diff --git a/api/tests/unit_tests/core/datasource/entities/test_common_entities.py b/api/tests/unit_tests/core/datasource/entities/test_common_entities.py new file mode 100644 index 
0000000000..0ee4928105 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_common_entities.py @@ -0,0 +1,31 @@ +from core.datasource.entities.common_entities import I18nObject + + +def test_i18n_object_fallback(): + # Only en_US provided + obj = I18nObject(en_US="Hello") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "Hello" + assert obj.pt_BR == "Hello" + assert obj.ja_JP == "Hello" + + # Some fields provided + obj = I18nObject(en_US="Hello", zh_Hans="你好") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "你好" + assert obj.pt_BR == "Hello" + assert obj.ja_JP == "Hello" + + +def test_i18n_object_all_fields(): + obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "你好" + assert obj.pt_BR == "Olá" + assert obj.ja_JP == "こんにちは" + + +def test_i18n_object_to_dict(): + obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは") + expected_dict = {"en_US": "Hello", "zh_Hans": "你好", "pt_BR": "Olá", "ja_JP": "こんにちは"} + assert obj.to_dict() == expected_dict diff --git a/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py b/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py new file mode 100644 index 0000000000..a8c8d31537 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py @@ -0,0 +1,275 @@ +from unittest.mock import patch + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceInvokeMeta, + DatasourceLabel, + DatasourceMessage, + DatasourceParameter, + DatasourceProviderEntity, + DatasourceProviderEntityWithPlugin, + DatasourceProviderIdentity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetWebsiteCrawlRequest, + OnlineDocumentInfo, + OnlineDocumentPage, + OnlineDocumentPageContent, + OnlineDocumentPagesMessage, + 
OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, + OnlineDriveFile, + OnlineDriveFileBucket, + WebsiteCrawlMessage, + WebSiteInfo, + WebSiteInfoDetail, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolLabelEnum + + +def test_datasource_provider_type(): + assert DatasourceProviderType.value_of("online_document") == DatasourceProviderType.ONLINE_DOCUMENT + assert DatasourceProviderType.value_of("local_file") == DatasourceProviderType.LOCAL_FILE + + with pytest.raises(ValueError, match="invalid mode value invalid"): + DatasourceProviderType.value_of("invalid") + + +def test_datasource_parameter_type(): + param_type = DatasourceParameter.DatasourceParameterType.STRING + assert param_type.as_normal_type() == "string" + assert param_type.cast_value("test") == "test" + + param_type = DatasourceParameter.DatasourceParameterType.NUMBER + assert param_type.cast_value("123") == 123 + + +def test_datasource_parameter(): + param = DatasourceParameter.get_simple_instance( + name="test_param", + typ=DatasourceParameter.DatasourceParameterType.STRING, + required=True, + options=["opt1", "opt2"], + ) + assert param.name == "test_param" + assert param.type == DatasourceParameter.DatasourceParameterType.STRING + assert param.required is True + assert len(param.options) == 2 + assert param.options[0].value == "opt1" + + param_no_options = DatasourceParameter.get_simple_instance( + name="test_param_2", typ=DatasourceParameter.DatasourceParameterType.NUMBER, required=False + ) + assert param_no_options.options == [] + + # Test init_frontend_parameter + # For STRING, it should just return the value as is (or cast to str) + frontend_param = param.init_frontend_parameter("val") + assert frontend_param == "val" + + # Test parameter type methods + assert DatasourceParameter.DatasourceParameterType.STRING.as_normal_type() == "string" + assert 
DatasourceParameter.DatasourceParameterType.NUMBER.as_normal_type() == "number" + assert DatasourceParameter.DatasourceParameterType.SECRET_INPUT.as_normal_type() == "string" + + assert DatasourceParameter.DatasourceParameterType.NUMBER.cast_value("10.5") == 10.5 + assert DatasourceParameter.DatasourceParameterType.BOOLEAN.cast_value("true") is True + assert DatasourceParameter.DatasourceParameterType.FILES.cast_value(["f1", "f2"]) == ["f1", "f2"] + + +def test_datasource_identity(): + label = I18nObject(en_US="label", zh_Hans="标签") + identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider", icon="icon") + assert identity.author == "author" + assert identity.name == "name" + assert identity.label == label + assert identity.provider == "provider" + assert identity.icon == "icon" + + +def test_datasource_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider") + description = I18nObject(en_US="desc", zh_Hans="描述") + + entity = DatasourceEntity( + identity=identity, + description=description, + parameters=None, # Should be handled by validator + ) + assert entity.parameters == [] + + param = DatasourceParameter.get_simple_instance("p1", DatasourceParameter.DatasourceParameterType.STRING, True) + entity_with_params = DatasourceEntity(identity=identity, description=description, parameters=[param]) + assert entity_with_params.parameters == [param] + + +def test_datasource_provider_identity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon.png", label=label, tags=[ToolLabelEnum.SEARCH] + ) + + assert identity.author == "author" + assert identity.name == "name" + assert identity.description == description + assert identity.icon == "icon.png" + assert identity.label == label + 
assert identity.tags == [ToolLabelEnum.SEARCH] + + # Test generate_datasource_icon_url + with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config: + mock_config.CONSOLE_API_URL = "http://api.example.com" + url = identity.generate_datasource_icon_url("tenant123") + assert "http://api.example.com/console/api/workspaces/current/plugin/icon" in url + assert "tenant_id=tenant123" in url + assert "filename=icon.png" in url + + # Test hardcoded icon + identity.icon = "https://assets.dify.ai/images/File%20Upload.svg" + assert identity.generate_datasource_icon_url("tenant123") == identity.icon + + # Test with empty CONSOLE_API_URL + identity.icon = "test.png" + with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config: + mock_config.CONSOLE_API_URL = None + url = identity.generate_datasource_icon_url("tenant123") + assert url.startswith("/console/api/workspaces/current/plugin/icon") + + +def test_datasource_provider_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon", label=label + ) + + entity = DatasourceProviderEntity( + identity=identity, + provider_type=DatasourceProviderType.ONLINE_DOCUMENT, + credentials_schema=[], + oauth_schema=None, + ) + assert entity.identity == identity + assert entity.provider_type == DatasourceProviderType.ONLINE_DOCUMENT + assert entity.credentials_schema == [] + + +def test_datasource_provider_entity_with_plugin(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon", label=label + ) + + entity = DatasourceProviderEntityWithPlugin( + identity=identity, provider_type=DatasourceProviderType.ONLINE_DOCUMENT, datasources=[] + ) + assert entity.datasources == 
[] + + +def test_datasource_invoke_meta(): + meta = DatasourceInvokeMeta(time_cost=1.5, error="some error", tool_config={"k": "v"}) + assert meta.time_cost == 1.5 + assert meta.error == "some error" + assert meta.tool_config == {"k": "v"} + + d = meta.to_dict() + assert d == {"time_cost": 1.5, "error": "some error", "tool_config": {"k": "v"}} + + empty_meta = DatasourceInvokeMeta.empty() + assert empty_meta.time_cost == 0.0 + assert empty_meta.error is None + assert empty_meta.tool_config == {} + + error_meta = DatasourceInvokeMeta.error_instance("fatal error") + assert error_meta.time_cost == 0.0 + assert error_meta.error == "fatal error" + assert error_meta.tool_config == {} + + +def test_datasource_label(): + label_obj = I18nObject(en_US="label", zh_Hans="标签") + ds_label = DatasourceLabel(name="name", label=label_obj, icon="icon") + assert ds_label.name == "name" + assert ds_label.label == label_obj + assert ds_label.icon == "icon" + + +def test_online_document_models(): + page = OnlineDocumentPage( + page_id="p1", + page_name="name", + page_icon={"type": "emoji"}, + type="page", + last_edited_time="2023-01-01", + parent_id=None, + ) + assert page.page_id == "p1" + + info = OnlineDocumentInfo(workspace_id="w1", workspace_name="name", workspace_icon="icon", total=1, pages=[page]) + assert info.total == 1 + + msg = OnlineDocumentPagesMessage(result=[info]) + assert msg.result == [info] + + req = GetOnlineDocumentPageContentRequest(workspace_id="w1", page_id="p1", type="page") + assert req.workspace_id == "w1" + + content = OnlineDocumentPageContent(workspace_id="w1", page_id="p1", content="hello") + assert content.content == "hello" + + resp = GetOnlineDocumentPageContentResponse(result=content) + assert resp.result == content + + +def test_website_crawl_models(): + req = GetWebsiteCrawlRequest(crawl_parameters={"url": "http://test.com"}) + assert req.crawl_parameters == {"url": "http://test.com"} + + detail = WebSiteInfoDetail(source_url="http://test.com", 
content="content", title="title", description="desc") + assert detail.title == "title" + + info = WebSiteInfo(status="completed", web_info_list=[detail], total=1, completed=1) + assert info.status == "completed" + + msg = WebsiteCrawlMessage(result=info) + assert msg.result == info + + # Test default values + msg_default = WebsiteCrawlMessage() + assert msg_default.result.status == "" + assert msg_default.result.web_info_list == [] + + +def test_online_drive_models(): + file = OnlineDriveFile(id="f1", name="file.txt", size=100, type="file") + assert file.name == "file.txt" + + bucket = OnlineDriveFileBucket(bucket="b1", files=[file], is_truncated=False, next_page_parameters=None) + assert bucket.bucket == "b1" + + req = OnlineDriveBrowseFilesRequest(bucket="b1", prefix="folder1", max_keys=10, next_page_parameters=None) + assert req.prefix == "folder1" + + resp = OnlineDriveBrowseFilesResponse(result=[bucket]) + assert resp.result == [bucket] + + dl_req = OnlineDriveDownloadFileRequest(id="f1", bucket="b1") + assert dl_req.id == "f1" + + +def test_datasource_message(): + # Use proper dict for message to avoid Pydantic Union validation ambiguity/crashes + msg = DatasourceMessage(type="text", message={"text": "hello"}) + assert msg.message.text == "hello" + + msg_json = DatasourceMessage(type="json", message={"json_object": {"k": "v"}}) + assert msg_json.message.json_object == {"k": "v"} diff --git a/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py b/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py new file mode 100644 index 0000000000..5bf7362a8a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py @@ -0,0 +1,57 @@ +from unittest.mock import MagicMock + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, +) +from 
core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin + + +class TestLocalFileDatasourcePlugin: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test-tenant-id" + icon = "test-icon" + plugin_unique_identifier = "test-plugin-id" + + # Act + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.tenant_id == tenant_id + assert plugin.plugin_unique_identifier == plugin_unique_identifier + assert plugin.entity == mock_entity + assert plugin.runtime == mock_runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE + + def test_get_icon_url(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon" + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon=icon, plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.get_icon_url("any-tenant-id") == icon diff --git a/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py b/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py new file mode 100644 index 0000000000..af2369ac4e --- /dev/null +++ b/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py @@ -0,0 +1,96 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities 
import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin +from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController + + +class TestLocalFileDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + # Act + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.LOCAL_FILE + + def test_validate_credentials(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + # Should not raise any exception + controller._validate_credentials("user_id", {"key": "value"}) + + def test_get_datasource_success(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "test_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity.icon = "test_icon" + + 
plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + datasource = controller.get_datasource("test_datasource") + + # Assert + assert isinstance(datasource, LocalFileDatasourcePlugin) + assert datasource.entity == mock_datasource_entity + assert datasource.tenant_id == tenant_id + assert datasource.icon == "test_icon" + assert datasource.plugin_unique_identifier == plugin_unique_identifier + + def test_get_datasource_not_found(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name test_datasource not found"): + controller.get_datasource("test_datasource") diff --git a/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py b/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py new file mode 100644 index 0000000000..e3a217725a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py @@ -0,0 +1,151 @@ +from unittest.mock import MagicMock, patch + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin + + +class TestOnlineDocumentDatasourcePlugin: + def test_init(self): + # Arrange + 
entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + # Act + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.tenant_id == tenant_id + assert plugin.icon == icon + assert plugin.plugin_unique_identifier == plugin_unique_identifier + + def test_get_online_document_pages(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test_key"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + datasource_parameters = {"param": "value"} + provider_type = "test_type" + + mock_generator = MagicMock() + + # Patch PluginDatasourceManager to isolate plugin behavior from external dependencies + with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.get_online_document_pages.return_value = mock_generator + + # Act + result = plugin.get_online_document_pages( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert result == mock_generator + mock_manager_instance.get_online_document_pages.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + 
datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def test_get_online_document_page_content(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test_key"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + datasource_parameters = MagicMock(spec=GetOnlineDocumentPageContentRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.get_online_document_page_content.return_value = mock_generator + + # Act + result = plugin.get_online_document_page_content( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert result == mock_generator + mock_manager_instance.get_online_document_page_content.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, runtime=runtime, tenant_id="test", icon="test", 
plugin_unique_identifier="test" + ) + + # Act + result = plugin.datasource_provider_type() + + # Assert + assert result == DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py b/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py new file mode 100644 index 0000000000..cfdd05e0b2 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController + + +class TestOnlineDocumentDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_uid" + tenant_id = "test_tenant_id" + + # Act + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Assert + assert controller.provider_type == DatasourceProviderType.ONLINE_DOCUMENT + + 
def test_get_datasource_success(self): + # Arrange + from core.datasource.entities.datasource_entities import DatasourceIdentity + + mock_datasource_entity = MagicMock(spec=DatasourceEntity) + mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity) + mock_datasource_entity.identity.name = "target_datasource" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity = MagicMock() + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_uid" + tenant_id = "test_tenant_id" + + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id="test_plugin_id", + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Act + result = controller.get_datasource("target_datasource") + + # Assert + assert isinstance(result, OnlineDocumentDatasourcePlugin) + assert result.entity == mock_datasource_entity + assert result.tenant_id == tenant_id + assert result.icon == "test_icon" + assert result.plugin_unique_identifier == plugin_unique_identifier + assert result.runtime.tenant_id == tenant_id + + def test_get_datasource_not_found(self): + # Arrange + from core.datasource.entities.datasource_entities import DatasourceIdentity + + mock_datasource_entity = MagicMock(spec=DatasourceEntity) + mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity) + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.datasources = [mock_datasource_entity] + + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id="test_plugin_id", + plugin_unique_identifier="test_plugin_uid", + tenant_id="test_tenant_id", + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name missing_datasource not found"): + controller.get_datasource("missing_datasource") diff 
--git a/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py new file mode 100644 index 0000000000..6c8b644871 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py @@ -0,0 +1,147 @@ +from unittest.mock import MagicMock, patch + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceProviderType, + OnlineDriveBrowseFilesRequest, + OnlineDriveDownloadFileRequest, +) +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin + + +class TestOnlineDriveDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + # Act + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.tenant_id == tenant_id + assert plugin.icon == icon + assert plugin.plugin_unique_identifier == plugin_unique_identifier + + def test_online_drive_browse_files(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"token": "test_token"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + 
plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + request = MagicMock(spec=OnlineDriveBrowseFilesRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.online_drive_browse_files.return_value = mock_generator + + # Act + result = plugin.online_drive_browse_files(user_id=user_id, request=request, provider_type=provider_type) + + # Assert + assert result == mock_generator + mock_manager_instance.online_drive_browse_files.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def test_online_drive_download_file(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"token": "test_token"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + request = MagicMock(spec=OnlineDriveDownloadFileRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.online_drive_download_file.return_value = mock_generator + + # Act + result = plugin.online_drive_download_file(user_id=user_id, request=request, provider_type=provider_type) + 
+ # Assert + assert result == mock_generator + mock_manager_instance.online_drive_download_file.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + plugin = OnlineDriveDatasourcePlugin( + entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act + result = plugin.datasource_provider_type() + + # Assert + assert result == DatasourceProviderType.ONLINE_DRIVE diff --git a/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py new file mode 100644 index 0000000000..2824ddd8ed --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin +from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController + + +class TestOnlineDriveDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + # Act + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert 
controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.ONLINE_DRIVE + + def test_get_datasource_success(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "test_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + datasource = controller.get_datasource("test_datasource") + + # Assert + assert isinstance(datasource, OnlineDriveDatasourcePlugin) + assert datasource.entity == mock_datasource_entity + assert datasource.tenant_id == tenant_id + assert datasource.icon == "test_icon" + assert datasource.plugin_unique_identifier == plugin_unique_identifier + assert datasource.runtime.tenant_id == tenant_id + + def test_get_datasource_not_found(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name test_datasource not found"): + 
controller.get_datasource("test_datasource") diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py new file mode 100644 index 0000000000..a7c93242cd --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -0,0 +1,409 @@ +import base64 +import hashlib +import hmac +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from core.datasource.datasource_file_manager import DatasourceFileManager +from models.model import MessageFile, UploadFile +from models.tools import ToolFile + + +class TestDatasourceFileManager: + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.os.urandom") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_sign_file(self, mock_config, mock_urandom, mock_time): + # Setup + mock_config.FILES_URL = "http://localhost:5001" + mock_config.SECRET_KEY = "test_secret" + mock_time.return_value = 1700000000 + mock_urandom.return_value = b"1234567890abcdef" # 16 bytes + + datasource_file_id = "file_id_123" + extension = ".png" + + # Execute + signed_url = DatasourceFileManager.sign_file(datasource_file_id, extension) + + # Verify + assert signed_url.startswith("http://localhost:5001/files/datasources/file_id_123.png?") + assert "timestamp=1700000000" in signed_url + assert f"nonce={mock_urandom.return_value.hex()}" in signed_url + assert "sign=" in signed_url + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.os.urandom") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time): + # Setup + mock_config.FILES_URL = "http://localhost:5001" + mock_config.SECRET_KEY = None # Empty secret + mock_time.return_value = 1700000000 + mock_urandom.return_value = b"1234567890abcdef" + + # 
Execute + signed_url = DatasourceFileManager.sign_file("file_id", ".png") + assert "sign=" in signed_url + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_verify_file(self, mock_config, mock_time): + # Setup + mock_config.SECRET_KEY = "test_secret" + mock_config.FILES_ACCESS_TIMEOUT = 300 + mock_time.return_value = 1700000000 + + datasource_file_id = "file_id_123" + timestamp = "1699999800" # 200 seconds ago + nonce = "some_nonce" + + # Manually calculate sign + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = b"test_secret" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + # Execute & Verify Success + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True + + # Verify Failure - Wrong Sign + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False + + # Verify Failure - Timeout + mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout) + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_verify_file_empty_secret(self, mock_config, mock_time): + # Setup + mock_config.SECRET_KEY = "" # Empty string secret + mock_config.FILES_ACCESS_TIMEOUT = 300 + mock_time.return_value = 1700000000 + + datasource_file_id = "file_id_123" + timestamp = "1699999800" + nonce = "some_nonce" + + # Calculate with empty secret + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + assert 
DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + user_id = "user_123" + tenant_id = "tenant_456" + file_binary = b"fake binary data" + mimetype = "image/png" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mimetype, + filename="test.png", + ) + + # Verify + assert upload_file.tenant_id == tenant_id + assert upload_file.name == "test.png" + assert upload_file.size == len(file_binary) + assert upload_file.mime_type == mimetype + assert upload_file.key == f"datasources/{tenant_id}/unique_hex.png" + + mock_storage.save.assert_called_once_with(upload_file.key, file_binary) + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw_filename_no_extension(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + user_id = "user_123" + tenant_id = "tenant_456" + file_binary = b"fake binary data" + mimetype = "image/png" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=file_binary, + 
mimetype=mimetype, + filename="test", # No extension + ) + + # Verify + assert upload_file.name == "test.png" # Should append extension + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + @patch("core.datasource.datasource_file_manager.guess_extension") + def test_create_file_by_raw_unknown_extension(self, mock_guess_ext, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_guess_ext.return_value = None # Cannot guess + mock_uuid.return_value = MagicMock(hex="unique_hex") + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id="user", + tenant_id="tenant", + conversation_id=None, + file_binary=b"data", + mimetype="application/x-unknown", + ) + + # Verify + assert upload_file.extension == ".bin" + assert upload_file.name == "unique_hex.bin" + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw_no_filename(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id="user_123", + tenant_id="tenant_456", + conversation_id=None, + file_binary=b"data", + mimetype="application/pdf", + ) + + # Verify + assert upload_file.name == "unique_hex.pdf" + assert upload_file.extension == ".pdf" + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def 
test_create_file_by_url_mimetype_from_guess(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"bits" + mock_response.headers = {} # No content-type in headers + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.png" + ) + + # Verify + assert tool_file.mimetype == "image/png" # Guessed from .png in URL + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_mimetype_default(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"bits" + mock_response.headers = {} + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", + tenant_id="tenant_456", + file_url="https://example.com/unknown", # No extension, no headers + ) + + # Verify + assert tool_file.mimetype == "application/octet-stream" + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_success(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"downloaded bits" + mock_response.headers = {"Content-Type": "image/jpeg"} + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = 
DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.jpg" + ) + + # Verify + assert tool_file.mimetype == "image/jpeg" + assert tool_file.size == len(b"downloaded bits") + assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg" + mock_storage.save.assert_called_once() + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + def test_create_file_by_url_timeout(self, mock_ssrf): + # Setup + mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout") + + # Execute & Verify + with pytest.raises(ValueError, match="timeout when downloading file"): + DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/large.file" + ) + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary(self, mock_storage, mock_db): + # Setup + mock_upload_file = MagicMock(spec=UploadFile) + mock_upload_file.key = "some_key" + mock_upload_file.mime_type = "image/png" + + mock_query = mock_db.session.query.return_value + mock_where = mock_query.where.return_value + mock_where.first.return_value = mock_upload_file + + mock_storage.load_once.return_value = b"file content" + + # Execute + result = DatasourceFileManager.get_file_binary("file_id") + + # Verify + assert result == (b"file content", "image/png") + + # Case: Not found + mock_where.first.return_value = None + assert DatasourceFileManager.get_file_binary("unknown") is None + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id(self, mock_storage, mock_db): + # Setup + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = "http://localhost/files/tools/tool_id.png" + + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.file_key = "tool_key" + mock_tool_file.mimetype = 
"image/png" + + # Mock query sequence + def mock_query(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + elif model == ToolFile: + m.where.return_value.first.return_value = mock_tool_file + return m + + mock_db.session.query.side_effect = mock_query + mock_storage.load_once.return_value = b"tool content" + + # Execute + result = DatasourceFileManager.get_file_binary_by_message_file_id("msg_file_id") + + # Verify + assert result == (b"tool content", "image/png") + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id_with_extension(self, mock_storage, mock_db): + # Test that it correctly parses tool_id even with extension in URL + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = "http://localhost/files/tools/abcdef.png" + + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "abcdef" + mock_tool_file.file_key = "tk" + mock_tool_file.mimetype = "image/png" + + def mock_query(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + else: + m.where.return_value.first.return_value = mock_tool_file + return m + + mock_db.session.query.side_effect = mock_query + mock_storage.load_once.return_value = b"bits" + + result = DatasourceFileManager.get_file_binary_by_message_file_id("m") + assert result == (b"bits", "image/png") + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db): + # Setup common mock + mock_query_obj = MagicMock() + mock_db.session.query.return_value = mock_query_obj + mock_query_obj.where.return_value.first.return_value = None + + # Case 1: Message file not found + assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None + + 
# Case 2: Message file found but tool file not found + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = None + + def mock_query_v2(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + else: + m.where.return_value.first.return_value = None + return m + + mock_db.session.query.side_effect = mock_query_v2 + assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_generator_by_upload_file_id(self, mock_storage, mock_db): + # Setup + mock_upload_file = MagicMock(spec=UploadFile) + mock_upload_file.key = "upload_key" + mock_upload_file.mime_type = "text/plain" + + mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file + + mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"]) + + # Execute + stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("upload_id") + + # Verify + assert mimetype == "text/plain" + assert list(stream) == [b"chunk1", b"chunk2"] + + # Case: Not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none") + assert stream is None + assert mimetype is None diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 52c91fb8c9..d5eeae912c 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -1,9 +1,15 @@ import types from collections.abc import Generator +import pytest + +from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities 
import DatasourceMessage +from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType +from core.datasource.errors import DatasourceProviderNotFoundError from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent @@ -15,6 +21,22 @@ def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, Non ) +def _drain_generator(gen: Generator[DatasourceMessage, None, object]) -> tuple[list[DatasourceMessage], object | None]: + messages: list[DatasourceMessage] = [] + try: + while True: + messages.append(next(gen)) + except StopIteration as e: + return messages, e.value + + +def _invalidate_recyclable_contextvars() -> None: + """ + Ensure RecyclableContextVar.get() raises LookupError until reset by code under test. + """ + RecyclableContextVar.increment_thread_recycles() + + def test_get_icon_url_calls_runtime(mocker): fake_runtime = mocker.Mock() fake_runtime.get_icon_url.return_value = "https://icon" @@ -30,6 +52,119 @@ def test_get_icon_url_calls_runtime(mocker): DatasourceManager.get_datasource_runtime.assert_called_once() +def test_get_datasource_runtime_delegates_to_provider_controller(mocker): + provider_controller = mocker.Mock() + provider_controller.get_datasource.return_value = object() + mocker.patch.object(DatasourceManager, "get_datasource_plugin_provider", return_value=provider_controller) + + runtime = DatasourceManager.get_datasource_runtime( + provider_id="prov/x", + datasource_name="ds", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + assert runtime is provider_controller.get_datasource.return_value + provider_controller.get_datasource.assert_called_once_with("ds") + + +@pytest.mark.parametrize( + ("datasource_type", "controller_path"), + [ + ( + 
DatasourceProviderType.ONLINE_DOCUMENT, + "core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.ONLINE_DRIVE, + "core.datasource.datasource_manager.OnlineDriveDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.WEBSITE_CRAWL, + "core.datasource.datasource_manager.WebsiteCrawlDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.LOCAL_FILE, + "core.datasource.datasource_manager.LocalFileDatasourcePluginProviderController", + ), + ], +) +def test_get_datasource_plugin_provider_creates_controller_and_caches(mocker, datasource_type, controller_path): + _invalidate_recyclable_contextvars() + + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + fetch = mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + ctrl_cls = mocker.patch(controller_path) + + first = DatasourceManager.get_datasource_plugin_provider( + provider_id=f"prov/{datasource_type.value}", + tenant_id="t1", + datasource_type=datasource_type, + ) + second = DatasourceManager.get_datasource_plugin_provider( + provider_id=f"prov/{datasource_type.value}", + tenant_id="t1", + datasource_type=datasource_type, + ) + + assert first is second + assert fetch.call_count == 1 + assert ctrl_cls.call_count == 1 + + +def test_get_datasource_plugin_provider_raises_when_provider_entity_missing(mocker): + _invalidate_recyclable_contextvars() + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=None, + ) + + with pytest.raises(DatasourceProviderNotFoundError, match="plugin provider prov/notfound not found"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/notfound", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + + +def 
test_get_datasource_plugin_provider_raises_for_unsupported_type(mocker): + _invalidate_recyclable_contextvars() + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + + with pytest.raises(ValueError, match="Unsupported datasource type"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/x", + tenant_id="t1", + datasource_type=types.SimpleNamespace(), # not a DatasourceProviderType at runtime + ) + + +def test_get_datasource_plugin_provider_raises_when_controller_none(mocker): + _invalidate_recyclable_contextvars() + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + mocker.patch( + "core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController", + return_value=None, + ) + + with pytest.raises(DatasourceProviderNotFoundError, match="Datasource provider prov/x not found"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/x", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + + def test_stream_online_results_yields_messages_online_document(mocker): # stub runtime to yield a text message def _doc_messages(**_): @@ -60,6 +195,148 @@ def test_stream_online_results_yields_messages_online_document(mocker): assert msgs[0].message.text == "hello" +def test_stream_online_results_sets_credentials_and_returns_empty_dict_online_document(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def get_online_document_page_content(self, **_kwargs): + yield from _gen_messages_text_only("hello") + + runtime = _Runtime() 
+ mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="cred", + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + messages, final_value = _drain_generator(gen) + + assert runtime.runtime.credentials == {"token": "t"} + assert [m.message.text for m in messages] == ["hello"] + assert final_value == {} + + +def test_stream_online_results_raises_when_missing_params(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def get_online_document_page_content(self, **_kwargs): + yield from _gen_messages_text_only("never") + + def online_drive_download_file(self, **_kwargs): + yield from _gen_messages_text_only("never") + + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=_Runtime()) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={}, + ) + + with pytest.raises(ValueError, match="datasource_param is required for ONLINE_DOCUMENT streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + with pytest.raises(ValueError, match="online_drive_request is required for ONLINE_DRIVE streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + 
provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + +def test_stream_online_results_yields_messages_and_returns_empty_dict_online_drive(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def online_drive_download_file(self, **_kwargs): + yield from _gen_messages_text_only("drive") + + runtime = _Runtime() + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="cred", + datasource_param=None, + online_drive_request=types.SimpleNamespace(id="fid", bucket="b"), + ) + messages, final_value = _drain_generator(gen) + + assert runtime.runtime.credentials == {"token": "t"} + assert [m.message.text for m in messages] == ["drive"] + assert final_value == {} + + +def test_stream_online_results_raises_for_unsupported_stream_type(mocker): + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=mocker.Mock()) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={}, + ) + + with pytest.raises(ValueError, match="Unsupported datasource type for streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="website_crawl", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + def test_stream_node_events_emits_events_online_document(mocker): # make manager's 
low-level stream produce TEXT only mocker.patch.object( @@ -93,6 +370,260 @@ def test_stream_node_events_emits_events_online_document(mocker): assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED +def test_stream_node_events_builds_file_and_variables_from_messages(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"), + meta={}, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + message=DatasourceMessage.TextMessage(text="hello"), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text="http://example.com"), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="a", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="b", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="x", variable_value=1, stream=False), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.JSON, + message=DatasourceMessage.JsonMessage(json_object={"k": "v"}), + meta=None, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + fake_tool_file = types.SimpleNamespace(mimetype="image/png") + + class _Session: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def scalar(self, _stmt): + 
return fake_tool_file + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) + mocker.patch( + "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE + ) + built = File( + tenant_id="t1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="tool_file_1", + extension=".png", + mime_type="image/png", + storage_key="k", + ) + build_from_mapping = mocker.patch( + "core.datasource.datasource_manager.file_factory.build_from_mapping", + return_value=built, + ) + + variable_pool = mocker.Mock() + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={"k": "v"}, + datasource_info={"info": "x"}, + variable_pool=variable_pool, + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + + build_from_mapping.assert_called_once() + variable_pool.add.assert_not_called() + + assert any(isinstance(e, StreamChunkEvent) and e.chunk == "hello" for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.chunk.startswith("Link: http") for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "a" for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "b" for e in events) + assert isinstance(events[-2], StreamChunkEvent) + assert events[-2].is_final is True + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.outputs["v"] == "ab" + assert events[-1].node_run_result.outputs["x"] == 1 + + +def test_stream_node_events_raises_when_toolfile_missing(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", 
return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/missing.png"), + meta={}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + class _Session: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def scalar(self, _stmt): + return None + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) + + with pytest.raises(ValueError, match="ToolFile not found for file_id=missing, tenant_id=t1"): + list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={}, + variable_pool=mocker.Mock(), + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + + +def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + file_in = File( + tenant_id="t1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="tf", + extension=".pdf", + mime_type="application/pdf", + storage_key="k", + ) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.FileMessage(file_marker="file_marker"), + meta={"file": file_in}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + variable_pool = 
mocker.Mock() + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={"k": "v"}, + variable_pool=variable_pool, + datasource_param=None, + online_drive_request=types.SimpleNamespace(id="id", bucket="b"), + ) + ) + + variable_pool.add.assert_called_once() + assert variable_pool.add.call_args[0][0] == ["nodeA", "file"] + assert variable_pool.add.call_args[0][1] == file_in + + completed = events[-1] + assert isinstance(completed, StreamCompletedEvent) + assert completed.node_run_result.outputs["file"] == file_in + assert completed.node_run_result.outputs["datasource_type"] == DatasourceProviderType.ONLINE_DRIVE + + +def test_stream_node_events_skips_file_build_for_non_online_types(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"), + meta={}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + build_from_mapping = mocker.patch("core.datasource.datasource_manager.file_factory.build_from_mapping") + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="website_crawl", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={}, + variable_pool=mocker.Mock(), + datasource_param=None, + online_drive_request=None, + ) + ) + + build_from_mapping.assert_not_called() + assert isinstance(events[-1], 
StreamCompletedEvent) + assert events[-1].node_run_result.outputs["file"] is None + + def test_get_upload_file_by_id_builds_file(mocker): # fake UploadFile row fake_row = types.SimpleNamespace( @@ -133,3 +664,27 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + + +def test_get_upload_file_by_id_raises_when_missing(mocker): + class _Q: + def where(self, *_args, **_kwargs): + return self + + def first(self): + return None + + class _S: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def query(self, *_): + return _Q() + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) + + with pytest.raises(ValueError, match="UploadFile not found for file_id=fid, tenant_id=t1"): + DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") diff --git a/api/tests/unit_tests/core/datasource/test_errors.py b/api/tests/unit_tests/core/datasource/test_errors.py new file mode 100644 index 0000000000..95986415b1 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_errors.py @@ -0,0 +1,64 @@ +from unittest.mock import MagicMock + +from core.datasource.entities.datasource_entities import DatasourceInvokeMeta +from core.datasource.errors import ( + DatasourceApiSchemaError, + DatasourceEngineInvokeError, + DatasourceInvokeError, + DatasourceNotFoundError, + DatasourceNotSupportedError, + DatasourceParameterValidationError, + DatasourceProviderCredentialValidationError, + DatasourceProviderNotFoundError, +) + + +class TestErrors: + def test_datasource_provider_not_found_error(self): + error = DatasourceProviderNotFoundError("Provider not found") + assert str(error) == "Provider not found" + assert isinstance(error, ValueError) + + def test_datasource_not_found_error(self): + error = DatasourceNotFoundError("Datasource not found") + assert 
str(error) == "Datasource not found" + assert isinstance(error, ValueError) + + def test_datasource_parameter_validation_error(self): + error = DatasourceParameterValidationError("Validation failed") + assert str(error) == "Validation failed" + assert isinstance(error, ValueError) + + def test_datasource_provider_credential_validation_error(self): + error = DatasourceProviderCredentialValidationError("Credential validation failed") + assert str(error) == "Credential validation failed" + assert isinstance(error, ValueError) + + def test_datasource_not_supported_error(self): + error = DatasourceNotSupportedError("Not supported") + assert str(error) == "Not supported" + assert isinstance(error, ValueError) + + def test_datasource_invoke_error(self): + error = DatasourceInvokeError("Invoke error") + assert str(error) == "Invoke error" + assert isinstance(error, ValueError) + + def test_datasource_api_schema_error(self): + error = DatasourceApiSchemaError("API schema error") + assert str(error) == "API schema error" + assert isinstance(error, ValueError) + + def test_datasource_engine_invoke_error(self): + mock_meta = MagicMock(spec=DatasourceInvokeMeta) + error = DatasourceEngineInvokeError(meta=mock_meta) + assert error.meta == mock_meta + assert isinstance(error, Exception) + + def test_datasource_engine_invoke_error_init(self): + # Test initialization with meta + meta = DatasourceInvokeMeta(time_cost=1.5, error="Engine failed") + error = DatasourceEngineInvokeError(meta=meta) + assert error.meta == meta + assert error.meta.time_cost == 1.5 + assert error.meta.error == "Engine failed" diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py new file mode 100644 index 0000000000..43f582feb7 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -0,0 +1,337 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from 
core.datasource.entities.datasource_entities import DatasourceMessage +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType +from models.tools import ToolFile + + +class TestDatasourceFileMessageTransformer: + def test_transform_text_and_link_messages(self): + # Setup + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, message=DatasourceMessage.TextMessage(text="hello") + ), + DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text="https://example.com"), + ), + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 2 + assert result[0].type == DatasourceMessage.MessageType.TEXT + assert result[0].message.text == "hello" + assert result[1].type == DatasourceMessage.MessageType.LINK + assert result[1].message.text == "https://example.com" + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + def test_transform_image_message_success(self, mock_guess_ext, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "file_id_123" + mock_tool_file.mimetype = "image/png" + mock_manager.create_file_by_url.return_value = mock_tool_file + mock_guess_ext.return_value = ".png" + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, + message=DatasourceMessage.TextMessage(text="https://example.com/image.png"), + meta={"some": "meta"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", 
tenant_id="tenant1", conversation_id="conv1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/file_id_123.png" + assert result[0].meta == {"some": "meta"} + mock_manager.create_file_by_url.assert_called_once_with( + user_id="user1", tenant_id="tenant1", file_url="https://example.com/image.png", conversation_id="conv1" + ) + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + def test_transform_image_message_failure(self, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_manager.create_file_by_url.side_effect = Exception("Download failed") + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, + message=DatasourceMessage.TextMessage(text="https://example.com/image.png"), + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.TEXT + assert "Failed to download image" in result[0].message.text + assert "Download failed" in result[0].message.text + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + def test_transform_blob_message_image(self, mock_guess_ext, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_456" + mock_tool_file.mimetype = "image/jpeg" + mock_manager.create_file_by_raw.return_value = mock_tool_file + mock_guess_ext.return_value = ".jpg" + + blob_data = b"fake-image-bits" + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=blob_data), + 
meta={"mime_type": "image/jpeg", "file_name": "test.jpg"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/blob_id_456.jpg" + mock_manager.create_file_by_raw.assert_called_once() + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + @patch("core.datasource.utils.message_transformer.guess_type") + def test_transform_blob_message_binary_guess_mimetype( + self, mock_guess_type, mock_guess_ext, mock_tool_file_manager_cls + ): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_789" + mock_tool_file.mimetype = "application/pdf" + mock_manager.create_file_by_raw.return_value = mock_tool_file + mock_guess_type.return_value = ("application/pdf", None) + mock_guess_ext.return_value = ".pdf" + + blob_data = b"fake-pdf-bits" + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=blob_data), + meta={"file_name": "test.pdf"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK + assert result[0].message.text == "/files/datasources/blob_id_789.pdf" + + def test_transform_blob_message_invalid_type(self): + # Setup + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, message=DatasourceMessage.TextMessage(text="not a blob") + ) + ] + + # Execute & Verify + with 
pytest.raises(ValueError, match="unexpected message type"): + list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + def test_transform_file_tool_file_image(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.TOOL_FILE + mock_file.related_id = "related_123" + mock_file.extension = ".png" + mock_file.type = FileType.IMAGE + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="ignored"), + meta={"file": mock_file}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/related_123.png" + + def test_transform_file_tool_file_binary(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.TOOL_FILE + mock_file.related_id = "related_456" + mock_file.extension = ".txt" + mock_file.type = FileType.DOCUMENT + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="ignored"), + meta={"file": mock_file}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.LINK + assert result[0].message.text == "/files/datasources/related_456.txt" + + def test_transform_file_other_transfer_method(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.REMOTE_URL + + msg = DatasourceMessage( + 
type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="remote image"), + meta={"file": mock_file}, + ) + messages = [msg] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0] == msg + + def test_transform_other_message_type(self): + # JSON type is yielded by the default 'else' block or the 'yield message' at the end + msg = DatasourceMessage( + type=DatasourceMessage.MessageType.JSON, message=DatasourceMessage.JsonMessage(json_object={"k": "v"}) + ) + messages = [msg] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0] == msg + + def test_get_datasource_file_url(self): + # Test with extension + url = DatasourceFileMessageTransformer.get_datasource_file_url("file1", ".jpg") + assert url == "/files/datasources/file1.jpg" + + # Test without extension + url = DatasourceFileMessageTransformer.get_datasource_file_url("file2", None) + assert url == "/files/datasources/file2.bin" + + def test_transform_blob_message_no_meta_filename(self): + # This tests line 70 where filename might be None + with patch("core.datasource.utils.message_transformer.ToolFileManager") as mock_tool_file_manager_cls: + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_no_name" + mock_tool_file.mimetype = "application/octet-stream" + mock_manager.create_file_by_raw.return_value = mock_tool_file + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=b"data"), + meta={}, # No mime_type, no file_name + ) + ] + + result = list( + 
DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK + assert result[0].message.text == "/files/datasources/blob_id_no_name.bin" + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + def test_transform_image_message_not_text_message(self, mock_tool_file_manager_cls): + # This tests line 24-26 where it checks if message is instance of TextMessage + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, message=DatasourceMessage.BlobMessage(blob=b"not-text") + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify - should yield unchanged if it's not a TextMessage + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE + assert isinstance(result[0].message, DatasourceMessage.BlobMessage) diff --git a/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py new file mode 100644 index 0000000000..2945eb5523 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py @@ -0,0 +1,101 @@ +from collections.abc import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + WebsiteCrawlMessage, +) +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin + + +class TestWebsiteCrawlDatasourcePlugin: + @pytest.fixture + def mock_entity(self): + entity = MagicMock(spec=DatasourceEntity) + entity.identity = 
MagicMock() + entity.identity.provider = "test-provider" + entity.identity.name = "test-name" + return entity + + @pytest.fixture + def mock_runtime(self): + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test-key"} + return runtime + + def test_init(self, mock_entity, mock_runtime): + # Arrange + tenant_id = "test-tenant-id" + icon = "test-icon" + plugin_unique_identifier = "test-plugin-id" + + # Act + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.tenant_id == tenant_id + assert plugin.plugin_unique_identifier == plugin_unique_identifier + assert plugin.entity == mock_entity + assert plugin.runtime == mock_runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self, mock_entity, mock_runtime): + # Arrange + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL + + def test_get_website_crawl(self, mock_entity, mock_runtime): + # Arrange + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id="test-tenant-id", + icon="test-icon", + plugin_unique_identifier="test-plugin-id", + ) + + user_id = "test-user-id" + datasource_parameters = {"url": "https://example.com"} + provider_type = "firecrawl" + + mock_message = MagicMock(spec=WebsiteCrawlMessage) + + # Mock PluginDatasourceManager + with patch("core.datasource.website_crawl.website_crawl_plugin.PluginDatasourceManager") as mock_manager_class: + mock_manager = mock_manager_class.return_value + mock_manager.get_website_crawl.return_value = (msg for msg in [mock_message]) + + # Act + result = plugin.get_website_crawl( + user_id=user_id, 
datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert isinstance(result, Generator) + messages = list(result) + assert len(messages) == 1 + assert messages[0] == mock_message + + mock_manager.get_website_crawl.assert_called_once_with( + tenant_id="test-tenant-id", + user_id=user_id, + datasource_provider="test-provider", + datasource_name="test-name", + credentials={"api_key": "test-key"}, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) diff --git a/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py new file mode 100644 index 0000000000..b7822ba800 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py @@ -0,0 +1,95 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController + + +class TestWebsiteCrawlDatasourcePluginProviderController: + @pytest.fixture + def mock_entity(self): + entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + entity.datasources = [] + entity.identity = MagicMock() + entity.identity.icon = "test-icon" + return entity + + def test_init(self, mock_entity): + # Arrange + plugin_id = "test-plugin-id" + plugin_unique_identifier = "test-unique-id" + tenant_id = "test-tenant-id" + + # Act + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert 
controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self, mock_entity): + # Arrange + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL + + def test_get_datasource_success(self, mock_entity): + # Arrange + datasource_name = "test-datasource" + tenant_id = "test-tenant-id" + plugin_unique_identifier = "test-unique-id" + + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity = MagicMock() + mock_datasource_entity.identity.name = datasource_name + mock_entity.datasources = [mock_datasource_entity] + + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + with patch( + "core.datasource.website_crawl.website_crawl_provider.WebsiteCrawlDatasourcePlugin" + ) as mock_plugin_class: + mock_plugin_instance = mock_plugin_class.return_value + result = controller.get_datasource(datasource_name) + + # Assert + assert result == mock_plugin_instance + mock_plugin_class.assert_called_once() + args, kwargs = mock_plugin_class.call_args + assert kwargs["entity"] == mock_datasource_entity + assert isinstance(kwargs["runtime"], DatasourceRuntime) + assert kwargs["runtime"].tenant_id == tenant_id + assert kwargs["tenant_id"] == tenant_id + assert kwargs["icon"] == "test-icon" + assert kwargs["plugin_unique_identifier"] == plugin_unique_identifier + + def test_get_datasource_not_found(self, mock_entity): + # Arrange + datasource_name = "non-existent" + mock_entity.datasources = [] + + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Act & Assert + with 
pytest.raises(ValueError, match=f"Datasource with name {datasource_name} not found"): + controller.get_datasource(datasource_name) diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py new file mode 100644 index 0000000000..7660967183 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -0,0 +1,329 @@ +"""Tests for OpikDataTrace workflow_trace changes. + +Covers: +- _seed_to_uuid4 helper: produces valid UUID4 strings deterministically +- prepare_opik_uuid helper: basic contract +- workflow_trace without message_id now creates a root span parented to None +- workflow_trace without message_id: node spans parent to root_span_id (not workflow_app_log_id) +- workflow_trace with message_id still creates root span keyed on workflow_run_id (unchanged path) +""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from unittest.mock import MagicMock, patch + +from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo +from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + +# A stable UUID4 used as the workflow_run_id throughout all tests. 
+_WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_workflow_trace_info( + *, + message_id: str | None = None, + workflow_app_log_id: str | None = None, + workflow_run_id: str = _WORKFLOW_RUN_ID, +) -> WorkflowTraceInfo: + """Return a minimal WorkflowTraceInfo suitable for unit testing.""" + return WorkflowTraceInfo( + message_id=message_id, + workflow_id="wf-id", + tenant_id="tenant-id", + workflow_run_id=workflow_run_id, + workflow_app_log_id=workflow_app_log_id, + workflow_run_elapsed_time=1.5, + workflow_run_status="succeeded", + workflow_run_inputs={"query": "hello"}, + workflow_run_outputs={"result": "world"}, + workflow_run_version="1", + total_tokens=42, + file_list=[], + query="hello", + start_time=datetime(2025, 1, 1, 12, 0, 0), + end_time=datetime(2025, 1, 1, 12, 0, 1), + metadata={"app_id": "app-abc"}, + conversation_id=None, + ) + + +def _make_opik_trace_instance() -> OpikDataTrace: + """Construct an OpikDataTrace with the Opik SDK client mocked out.""" + with patch("core.ops.opik_trace.opik_trace.Opik"): + from core.ops.entities.config_entity import OpikConfig + + config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") + instance = OpikDataTrace(config) + + instance.add_trace = MagicMock(return_value=MagicMock(id="mock-trace-id")) + instance.add_span = MagicMock() + instance.get_service_account_with_tenant = MagicMock(return_value=MagicMock()) + return instance + + +# --------------------------------------------------------------------------- +# _seed_to_uuid4 +# --------------------------------------------------------------------------- + + +class TestSeedToUuid4: + def test_returns_valid_uuid4_string(self): + result = _seed_to_uuid4("some-arbitrary-seed") + parsed = uuid.UUID(result) + assert parsed.version == 4 + + 
def test_is_deterministic(self): + assert _seed_to_uuid4("seed-abc") == _seed_to_uuid4("seed-abc") + + def test_different_seeds_give_different_results(self): + assert _seed_to_uuid4("seed-1") != _seed_to_uuid4("seed-2") + + def test_workflow_run_id_with_root_suffix_is_valid_uuid4(self): + """The primary use-case: deriving a root-span UUID from workflow_run_id + '-root'.""" + seed = _WORKFLOW_RUN_ID + "-root" + result = _seed_to_uuid4(seed) + parsed = uuid.UUID(result) + assert parsed.version == 4 + + def test_seed_and_seed_root_produce_different_uuids(self): + """Root span UUID must differ from the base workflow UUID to avoid ID collisions.""" + base = _seed_to_uuid4(_WORKFLOW_RUN_ID) + with_root = _seed_to_uuid4(_WORKFLOW_RUN_ID + "-root") + assert base != with_root + + +# --------------------------------------------------------------------------- +# prepare_opik_uuid +# --------------------------------------------------------------------------- + + +class TestPrepareOpikUuid: + def test_is_deterministic(self): + dt = datetime(2025, 6, 15, 10, 30, 0) + uid = str(uuid.uuid4()) + assert prepare_opik_uuid(dt, uid) == prepare_opik_uuid(dt, uid) + + def test_different_uuids_give_different_results(self): + dt = datetime(2025, 6, 15, 10, 30, 0) + assert prepare_opik_uuid(dt, str(uuid.uuid4())) != prepare_opik_uuid(dt, str(uuid.uuid4())) + + def test_none_datetime_does_not_raise(self): + assert prepare_opik_uuid(None, str(uuid.uuid4())) is not None + + def test_none_uuid_does_not_raise(self): + assert prepare_opik_uuid(datetime(2025, 1, 1), None) is not None + + +# --------------------------------------------------------------------------- +# workflow_trace — no message_id (new code path) +# --------------------------------------------------------------------------- + + +class TestWorkflowTraceWithoutMessageId: + def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): + instance = _make_opik_trace_instance() + fake_repo = MagicMock() + 
fake_repo.get_by_workflow_run.return_value = node_executions or [] + + with ( + patch("core.ops.opik_trace.opik_trace.db") as mock_db, + patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch( + "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=fake_repo, + ), + ): + mock_db.engine = MagicMock() + instance.workflow_trace(trace_info) + + return instance + + def _expected_root_span_id(self, trace_info: WorkflowTraceInfo): + return prepare_opik_uuid( + trace_info.start_time, + _seed_to_uuid4(trace_info.workflow_run_id + "-root"), + ) + + def test_root_span_is_created(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + assert instance.add_span.called + + def test_root_span_id_matches_expected(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + expected = self._expected_root_span_id(trace_info) + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["id"] == expected + + def test_root_span_has_no_parent(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["parent_span_id"] is None + + def test_trace_name_is_workflow_trace(self): + """Without message_id, the Opik trace itself should be named WORKFLOW_TRACE.""" + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + trace_kwargs = instance.add_trace.call_args_list[0][0][0] + assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE + + def test_root_span_name_is_workflow_trace(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE + + def 
test_root_span_has_workflow_tag(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert "workflow" in root_span_kwargs["tags"] + + def test_node_execution_spans_are_parented_to_root(self): + """Node spans must use root_span_id as parent, not any other ID.""" + trace_info = _make_workflow_trace_info(message_id=None) + expected_root_span_id = self._expected_root_span_id(trace_info) + + node_exec = MagicMock() + node_exec.id = str(uuid.uuid4()) + node_exec.title = "LLM Node" + node_exec.node_type = "llm" + node_exec.status = "succeeded" + node_exec.process_data = {} + node_exec.inputs = {"prompt": "hi"} + node_exec.outputs = {"text": "hello"} + node_exec.created_at = datetime(2025, 1, 1, 12, 0, 0) + node_exec.elapsed_time = 0.5 + node_exec.metadata = {} + + instance = self._run(trace_info, node_executions=[node_exec]) + + # call_args_list[0] = root span, [1] = node execution span + assert instance.add_span.call_count == 2 + node_span_kwargs = instance.add_span.call_args_list[1][0][0] + assert node_span_kwargs["parent_span_id"] == expected_root_span_id + + def test_node_span_not_parented_to_workflow_app_log_id(self): + """Old behaviour derived parent from workflow_app_log_id; that must no longer apply.""" + trace_info = _make_workflow_trace_info( + message_id=None, + workflow_app_log_id=str(uuid.uuid4()), + ) + + node_exec = MagicMock() + node_exec.id = str(uuid.uuid4()) + node_exec.title = "Tool Node" + node_exec.node_type = "tool" + node_exec.status = "succeeded" + node_exec.process_data = {} + node_exec.inputs = {} + node_exec.outputs = {} + node_exec.created_at = datetime(2025, 1, 1, 12, 0, 0) + node_exec.elapsed_time = 0.2 + node_exec.metadata = {} + + instance = self._run(trace_info, node_executions=[node_exec]) + + old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id) + node_span_kwargs = 
instance.add_span.call_args_list[1][0][0] + assert node_span_kwargs["parent_span_id"] != old_parent_id + + def test_root_span_id_differs_from_trace_id(self): + """The root span must have a different ID from the Opik trace to maintain correct hierarchy.""" + trace_info = _make_workflow_trace_info(message_id=None) + dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id + opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) + root_span_id = self._expected_root_span_id(trace_info) + assert root_span_id != opik_trace_id + + +# --------------------------------------------------------------------------- +# workflow_trace — with message_id (unchanged path, guard against regression) +# --------------------------------------------------------------------------- + + +class TestWorkflowTraceWithMessageId: + _MESSAGE_ID = str(uuid.uuid4()) + + def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): + instance = _make_opik_trace_instance() + fake_repo = MagicMock() + fake_repo.get_by_workflow_run.return_value = node_executions or [] + + with ( + patch("core.ops.opik_trace.opik_trace.db") as mock_db, + patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch( + "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=fake_repo, + ), + ): + mock_db.engine = MagicMock() + instance.workflow_trace(trace_info) + + return instance + + def test_trace_name_is_message_trace(self): + """With message_id, the Opik trace should be named MESSAGE_TRACE.""" + trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) + instance = self._run(trace_info) + + trace_kwargs = instance.add_trace.call_args_list[0][0][0] + assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE + + def test_root_span_uses_workflow_run_id_directly(self): + """When message_id is set, root_span_id = prepare_opik_uuid(start_time, workflow_run_id).""" + trace_info = 
_make_workflow_trace_info(message_id=self._MESSAGE_ID) + instance = self._run(trace_info) + + expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["id"] == expected_root_span_id + + def test_root_span_id_differs_from_no_message_id_case(self): + """The two branches must produce different root span IDs for the same workflow_run_id.""" + id_with_message = prepare_opik_uuid( + datetime(2025, 1, 1, 12, 0, 0), + _WORKFLOW_RUN_ID, + ) + id_without_message = prepare_opik_uuid( + datetime(2025, 1, 1, 12, 0, 0), + _seed_to_uuid4(_WORKFLOW_RUN_ID + "-root"), + ) + assert id_with_message != id_without_message + + def test_node_spans_parented_to_workflow_run_root_span(self): + """Node spans must still parent to root_span_id derived from workflow_run_id.""" + trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) + expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) + + node_exec = MagicMock() + node_exec.id = str(uuid.uuid4()) + node_exec.title = "LLM" + node_exec.node_type = "llm" + node_exec.status = "succeeded" + node_exec.process_data = {} + node_exec.inputs = {} + node_exec.outputs = {} + node_exec.created_at = datetime(2025, 1, 1, 12, 0, 0) + node_exec.elapsed_time = 0.3 + node_exec.metadata = {} + + instance = self._run(trace_info, node_executions=[node_exec]) + + node_span_kwargs = instance.add_span.call_args_list[1][0][0] + assert node_span_kwargs["parent_span_id"] == expected_root_span_id diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py new file mode 100644 index 0000000000..13285cdad0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -0,0 +1,813 @@ +""" +Unit tests for DatasetDocumentStore. 
+ +Tests cover all public methods and error paths of the DatasetDocumentStore class +which provides document storage and retrieval functionality for datasets in the RAG system. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.docstore.dataset_docstore import DatasetDocumentStore, DocumentSegment +from core.rag.models.document import AttachmentDocument, Document +from models.dataset import Dataset + + +class TestDatasetDocumentStoreInit: + """Tests for DatasetDocumentStore initialization.""" + + def test_init_with_all_parameters(self): + """Test initialization with dataset, user_id, and document_id.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + assert store._dataset == mock_dataset + assert store._user_id == "test-user-id" + assert store._document_id == "test-doc-id" + assert store.dataset_id == "test-dataset-id" + assert store.user_id == "test-user-id" + + def test_init_without_document_id(self): + """Test initialization without document_id.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + assert store._document_id is None + assert store.dataset_id == "test-dataset-id" + + +class TestDatasetDocumentStoreSerialization: + """Tests for to_dict and from_dict methods.""" + + def test_to_dict(self): + """Test serialization to dictionary.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.to_dict() + + assert result == {"dataset_id": "test-dataset-id"} + + def test_from_dict(self): + """Test deserialization from dictionary.""" + + config_dict = { + "dataset": MagicMock(spec=["id"]), + "user_id": "test-user", + 
"document_id": "test-doc", + } + config_dict["dataset"].id = "ds-123" + + store = DatasetDocumentStore.from_dict(config_dict) + + assert store._user_id == "test-user" + assert store._document_id == "test-doc" + + +class TestDatasetDocumentStoreDocs: + """Tests for the docs property.""" + + def test_docs_returns_document_dict(self): + """Test that docs property returns a dictionary of documents.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock(spec=DocumentSegment) + mock_segment.index_node_id = "node-1" + mock_segment.index_node_hash = "hash-1" + mock_segment.document_id = "doc-1" + mock_segment.dataset_id = "test-dataset-id" + mock_segment.content = "Test content" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalars.return_value.all.return_value = [mock_segment] + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.docs + + assert "node-1" in result + assert isinstance(result["node-1"], Document) + + def test_docs_empty_dataset(self): + """Test docs property with no segments.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalars.return_value.all.return_value = [] + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.docs + + assert result == {} + + +class TestDatasetDocumentStoreAddDocuments: + """Tests for add_documents method.""" + + def test_add_documents_new_document_with_embedding(self): + """Test adding new documents with embedding model.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "high_quality" 
+ mock_dataset.embedding_model_provider = "provider" + mock_dataset.embedding_model = "model" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_model_instance = MagicMock() + mock_model_instance.get_text_embedding_num_tokens.return_value = [10] + + with ( + patch("core.rag.docstore.dataset_docstore.db") as mock_db, + patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + ): + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + + mock_manager = MagicMock() + mock_manager.get_model_instance.return_value = mock_model_instance + mock_manager_class.return_value = mock_manager + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.add.assert_called() + mock_db.session.commit.assert_called() + + def test_add_documents_update_existing_document(self): + """Test updating existing document with allow_update=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + mock_dataset.embedding_model_provider = None + mock_dataset.embedding_model = None + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Updated content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "new-hash", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_existing_segment = MagicMock() + mock_existing_segment.id = "seg-1" + + with patch("core.rag.docstore.dataset_docstore.db") as 
mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.commit.assert_called() + + def test_add_documents_raises_when_not_allowed(self): + """Test that adding existing doc without allow_update raises ValueError.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_existing_segment = MagicMock() + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + with pytest.raises(ValueError, match="already exists"): + store.add_documents([mock_doc], allow_update=False) + + def test_add_documents_with_answer_metadata(self): + """Test adding document with answer in metadata.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + "answer": "Test answer", + } + mock_doc.attachments = None + mock_doc.children = None + + with 
patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.add.assert_called() + + def test_add_documents_with_invalid_document_type(self): + """Test that non-Document raises ValueError.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + with pytest.raises(ValueError, match="must be a Document"): + store.add_documents(["not a document"]) + + def test_add_documents_with_none_metadata(self): + """Test that document with None metadata raises ValueError.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = None + + with patch("core.rag.docstore.dataset_docstore.db"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + with pytest.raises(ValueError, match="metadata must be a dict"): + store.add_documents([mock_doc]) + + def test_add_documents_with_save_child(self): + """Test adding documents with save_child=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_child = MagicMock(spec=Document) + mock_child.page_content = "Child content" + mock_child.metadata = { + 
"doc_id": "child-1", + "doc_hash": "child-hash", + } + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + } + mock_doc.attachments = None + mock_doc.children = [mock_child] + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc], save_child=True) + + mock_db.session.add.assert_called() + + +class TestDatasetDocumentStoreExists: + """Tests for document_exists method.""" + + def test_document_exists_returns_true(self): + """Test document_exists returns True when segment exists.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.document_exists("doc-1") + + assert result is True + + def test_document_exists_returns_false(self): + """Test document_exists returns False when segment doesn't exist.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.document_exists("doc-1") + + assert result is False + + +class 
TestDatasetDocumentStoreGetDocument: + """Tests for get_document method.""" + + def test_get_document_success(self): + """Test getting a document successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock(spec=DocumentSegment) + mock_segment.index_node_id = "node-1" + mock_segment.index_node_hash = "hash-1" + mock_segment.document_id = "doc-1" + mock_segment.dataset_id = "test-dataset-id" + mock_segment.content = "Test content" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document("node-1", raise_error=False) + + assert isinstance(result, Document) + assert result.page_content == "Test content" + + def test_get_document_returns_none_when_not_found(self): + """Test get_document returns None when not found and raise_error=False.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document("nonexistent", raise_error=False) + + assert result is None + + def test_get_document_raises_when_not_found(self): + """Test get_document raises ValueError when not found and raise_error=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + with pytest.raises(ValueError, match="not found"): + store.get_document("nonexistent", raise_error=True) + + +class 
TestDatasetDocumentStoreDeleteDocument: + """Tests for delete_document method.""" + + def test_delete_document_success(self): + """Test deleting a document successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + store.delete_document("doc-1") + + mock_db.session.delete.assert_called_with(mock_segment) + mock_db.session.commit.assert_called() + + def test_delete_document_returns_none_when_not_found(self): + """Test delete_document returns None when not found and raise_error=False.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.delete_document("nonexistent", raise_error=False) + + assert result is None + + def test_delete_document_raises_when_not_found(self): + """Test delete_document raises ValueError when not found and raise_error=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + with pytest.raises(ValueError, match="not found"): + store.delete_document("nonexistent", raise_error=True) + + +class TestDatasetDocumentStoreHashOperations: + """Tests for set_document_hash and get_document_hash methods.""" + + def test_set_document_hash_success(self): + """Test setting document hash successfully.""" + + 
mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + mock_segment.index_node_hash = "old-hash" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + store.set_document_hash("doc-1", "new-hash") + + assert mock_segment.index_node_hash == "new-hash" + mock_db.session.commit.assert_called() + + def test_set_document_hash_returns_none_when_not_found(self): + """Test set_document_hash returns None when segment not found.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.set_document_hash("nonexistent", "new-hash") + + assert result is None + + def test_get_document_hash_success(self): + """Test getting document hash successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + mock_segment.index_node_hash = "test-hash" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_hash("doc-1") + + assert result == "test-hash" + + def test_get_document_hash_returns_none_when_not_found(self): + """Test get_document_hash returns None when segment not found.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", 
return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_hash("nonexistent") + + assert result is None + + +class TestDatasetDocumentStoreSegment: + """Tests for get_document_segment method.""" + + def test_get_document_segment_returns_segment(self): + """Test getting a document segment.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock(spec=DocumentSegment) + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalar.return_value = mock_segment + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_segment("doc-1") + + assert result == mock_segment + + def test_get_document_segment_returns_none(self): + """Test getting a non-existent document segment.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalar.return_value = None + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_segment("nonexistent") + + assert result is None + + +class TestDatasetDocumentStoreMultimodelBinding: + """Tests for add_multimodel_documents_binding method.""" + + def test_add_multimodel_documents_binding_with_attachments(self): + """Test adding multimodel document bindings.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + mock_attachment = MagicMock(spec=AttachmentDocument) + mock_attachment.metadata = {"doc_id": "attachment-1"} + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = 
DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_multimodel_documents_binding("seg-1", [mock_attachment]) + + mock_db.session.add.assert_called() + + def test_add_multimodel_documents_binding_without_attachments(self): + """Test adding bindings with None attachments.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_multimodel_documents_binding("seg-1", None) + + mock_db.session.add.assert_not_called() + + def test_add_multimodel_documents_binding_with_empty_list(self): + """Test adding bindings with empty list.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_multimodel_documents_binding("seg-1", []) + + mock_db.session.add.assert_not_called() + + +class TestDatasetDocumentStoreAddDocumentsUpdateChild: + """Tests for add_documents when updating existing documents with children.""" + + def test_add_documents_update_existing_with_children(self): + """Test updating existing document with save_child=True and children.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_child = MagicMock(spec=Document) + mock_child.page_content = "Updated child content" + mock_child.metadata = { + "doc_id": "child-1", + "doc_hash": 
"new-child-hash", + } + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Updated content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "new-hash", + } + mock_doc.attachments = None + mock_doc.children = [mock_child] + + mock_existing_segment = MagicMock() + mock_existing_segment.id = "seg-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc], save_child=True) + + mock_db.session.query.return_value.where.return_value.delete.assert_called() + mock_db.session.commit.assert_called() + + +class TestDatasetDocumentStoreAddDocumentsUpdateAnswer: + """Tests for add_documents when updating existing documents with answer metadata.""" + + def test_add_documents_update_existing_with_answer(self): + """Test updating existing document with answer in metadata.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Updated content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "new-hash", + "answer": "Updated answer", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_existing_segment = MagicMock() + mock_existing_segment.id = "seg-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + + with 
patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.commit.assert_called() diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py new file mode 100644 index 0000000000..a0db25174d --- /dev/null +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -0,0 +1,555 @@ +"""Unit tests for cached_embedding.py - CacheEmbedding class. + +This test file covers the methods not fully tested in test_embedding_service.py: +- embed_multimodal_documents +- embed_multimodal_query +- Error handling scenarios in embed_query (DEBUG mode) +""" + +import base64 +from decimal import Decimal +from unittest.mock import Mock, patch + +import numpy as np +import pytest +from sqlalchemy.exc import IntegrityError + +from core.rag.embedding.cached_embedding import CacheEmbedding +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from models.dataset import Embedding + + +class TestCacheEmbeddingMultimodalDocuments: + """Test suite for CacheEmbedding.embed_multimodal_documents method.""" + + @pytest.fixture + def mock_model_instance(self): + """Create a mock ModelInstance for testing.""" + model_instance = Mock() + model_instance.model = "vision-embedding-model" + model_instance.provider = "openai" + model_instance.credentials = {"api_key": "test-key"} + + model_type_instance = Mock() + model_instance.model_type_instance = model_type_instance + + model_schema = Mock() + model_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} + 
model_type_instance.get_model_schema.return_value = model_schema + + return model_instance + + @pytest.fixture + def sample_multimodal_result(self): + """Create a sample multimodal EmbeddingResult.""" + embedding_vector = np.random.randn(1536) + normalized_vector = (embedding_vector / np.linalg.norm(embedding_vector)).tolist() + + usage = EmbeddingUsage( + tokens=10, + total_tokens=10, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000001"), + currency="USD", + latency=0.5, + ) + + return EmbeddingResult( + model="vision-embedding-model", + embeddings=[normalized_vector], + usage=usage, + ) + + def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): + """Test embedding a single multimodal document when cache is empty.""" + cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + documents = [{"file_id": "file123", "content": "test content"}] + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 1 + assert isinstance(result[0], list) + assert len(result[0]) == 1536 + + mock_model_instance.invoke_multimodal_embedding.assert_called_once() + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + def test_embed_multiple_multimodal_documents_cache_miss(self, mock_model_instance): + """Test embedding multiple multimodal documents when cache is empty.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [ + {"file_id": "file1", "content": "content 1"}, + {"file_id": "file2", "content": "content 2"}, + {"file_id": "file3", "content": "content 3"}, + ] + + embeddings = [] + for _ in range(3): + vector = np.random.randn(1536) + 
normalized = (vector / np.linalg.norm(vector)).tolist() + embeddings.append(normalized) + + usage = EmbeddingUsage( + tokens=30, + total_tokens=30, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000003"), + currency="USD", + latency=0.8, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=embeddings, + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 3 + assert all(len(emb) == 1536 for emb in result) + + def test_embed_multimodal_documents_cache_hit(self, mock_model_instance): + """Test embedding multimodal documents when embeddings are cached.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "file123"}] + + cached_vector = np.random.randn(1536) + normalized_cached = (cached_vector / np.linalg.norm(cached_vector)).tolist() + + mock_cached_embedding = Mock(spec=Embedding) + mock_cached_embedding.get_embedding.return_value = normalized_cached + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 1 + assert result[0] == normalized_cached + mock_model_instance.invoke_multimodal_embedding.assert_not_called() + + def test_embed_multimodal_documents_partial_cache_hit(self, mock_model_instance): + """Test embedding multimodal documents with mixed cache hits and misses.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [ + {"file_id": "cached_file"}, + {"file_id": "new_file_1"}, + {"file_id": "new_file_2"}, + ] + + 
cached_vector = np.random.randn(1536) + normalized_cached = (cached_vector / np.linalg.norm(cached_vector)).tolist() + + mock_cached_embedding = Mock(spec=Embedding) + mock_cached_embedding.get_embedding.return_value = normalized_cached + + new_embeddings = [] + for _ in range(2): + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + new_embeddings.append(normalized) + + usage = EmbeddingUsage( + tokens=20, + total_tokens=20, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000002"), + currency="USD", + latency=0.6, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=new_embeddings, + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + call_count = [0] + + def mock_filter_by(**kwargs): + call_count[0] += 1 + mock_query = Mock() + if call_count[0] == 1: + mock_query.first.return_value = mock_cached_embedding + else: + mock_query.first.return_value = None + return mock_query + + mock_session.query.return_value.filter_by = mock_filter_by + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 3 + assert result[0] == normalized_cached + + def test_embed_multimodal_documents_nan_handling(self, mock_model_instance): + """Test handling of NaN values in multimodal embeddings.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "valid"}, {"file_id": "nan"}] + + valid_vector = np.random.randn(1536).tolist() + nan_vector = [float("nan")] * 1536 + + usage = EmbeddingUsage( + tokens=20, + total_tokens=20, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000002"), + currency="USD", + latency=0.5, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[valid_vector, nan_vector], + usage=usage, + ) + 
+ with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 2 + assert result[0] is not None + assert result[1] is None + + mock_logger.warning.assert_called_once() + + def test_embed_multimodal_documents_large_batch(self, mock_model_instance): + """Test embedding large batch of multimodal documents respecting MAX_CHUNKS.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": f"file{i}"} for i in range(25)] + + def create_batch_result(batch_size): + embeddings = [] + for _ in range(batch_size): + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + embeddings.append(normalized) + + usage = EmbeddingUsage( + tokens=batch_size * 10, + total_tokens=batch_size * 10, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal(str(batch_size * 0.000001)), + currency="USD", + latency=0.5, + ) + + return EmbeddingResult( + model="vision-embedding-model", + embeddings=embeddings, + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)] + mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 25 + assert mock_model_instance.invoke_multimodal_embedding.call_count == 3 + + def test_embed_multimodal_documents_api_error(self, mock_model_instance): + """Test handling of API errors during multimodal embedding.""" + 
cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "file123"}] + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") + + with pytest.raises(Exception) as exc_info: + cache_embedding.embed_multimodal_documents(documents) + + assert "API Error" in str(exc_info.value) + mock_session.rollback.assert_called() + + def test_embed_multimodal_documents_integrity_error_during_transform( + self, mock_model_instance, sample_multimodal_result + ): + """Test handling of IntegrityError during embedding transformation.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "file123"}] + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result + + mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None) + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 1 + mock_session.rollback.assert_called() + + +class TestCacheEmbeddingMultimodalQuery: + """Test suite for CacheEmbedding.embed_multimodal_query method.""" + + @pytest.fixture + def mock_model_instance(self): + """Create a mock ModelInstance for testing.""" + model_instance = Mock() + model_instance.model = "vision-embedding-model" + model_instance.provider = "openai" + model_instance.credentials = {"api_key": "test-key"} + return model_instance + + def test_embed_multimodal_query_cache_miss(self, mock_model_instance): + """Test embedding multimodal query when Redis cache is empty.""" + cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + document = {"file_id": "file123"} + + vector = 
np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[normalized], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + result = cache_embedding.embed_multimodal_query(document) + + assert isinstance(result, list) + assert len(result) == 1536 + mock_redis.setex.assert_called_once() + + def test_embed_multimodal_query_cache_hit(self, mock_model_instance): + """Test embedding multimodal query when Redis cache has the value.""" + cache_embedding = CacheEmbedding(mock_model_instance) + document = {"file_id": "file123"} + + embedding_vector = np.random.randn(1536) + vector_bytes = embedding_vector.tobytes() + encoded_vector = base64.b64encode(vector_bytes).decode("utf-8") + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = encoded_vector.encode() + + result = cache_embedding.embed_multimodal_query(document) + + assert isinstance(result, list) + assert len(result) == 1536 + mock_redis.expire.assert_called_once() + mock_model_instance.invoke_multimodal_embedding.assert_not_called() + + def test_embed_multimodal_query_nan_handling(self, mock_model_instance): + """Test handling of NaN values in multimodal query embeddings.""" + cache_embedding = CacheEmbedding(mock_model_instance) + + nan_vector = [float("nan")] * 1536 + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( 
+ model="vision-embedding-model", + embeddings=[nan_vector], + usage=usage, + ) + + document = {"file_id": "file123"} + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + with pytest.raises(ValueError) as exc_info: + cache_embedding.embed_multimodal_query(document) + + assert "Normalized embedding is nan" in str(exc_info.value) + + def test_embed_multimodal_query_api_error(self, mock_model_instance): + """Test handling of API errors during multimodal query embedding.""" + cache_embedding = CacheEmbedding(mock_model_instance) + document = {"file_id": "file123"} + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = False + + with pytest.raises(Exception) as exc_info: + cache_embedding.embed_multimodal_query(document) + + assert "API Error" in str(exc_info.value) + + def test_embed_multimodal_query_redis_set_error(self, mock_model_instance): + """Test handling of Redis set errors during multimodal query embedding.""" + cache_embedding = CacheEmbedding(mock_model_instance) + document = {"file_id": "file123"} + + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[normalized], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + 
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + mock_redis.setex.side_effect = RuntimeError("Redis Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = True + + with pytest.raises(RuntimeError): + cache_embedding.embed_multimodal_query(document) + + +class TestCacheEmbeddingQueryErrors: + """Test suite for error handling in CacheEmbedding.embed_query method.""" + + @pytest.fixture + def mock_model_instance(self): + """Create a mock ModelInstance for testing.""" + model_instance = Mock() + model_instance.model = "text-embedding-ada-002" + model_instance.provider = "openai" + model_instance.credentials = {"api_key": "test-key"} + return model_instance + + def test_embed_query_api_error_debug_mode(self, mock_model_instance): + """Test handling of API errors in debug mode.""" + cache_embedding = CacheEmbedding(mock_model_instance) + query = "test query" + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_text_embedding.side_effect = RuntimeError("API Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = True + + with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with pytest.raises(RuntimeError) as exc_info: + cache_embedding.embed_query(query) + + assert "API Error" in str(exc_info.value) + mock_logger.exception.assert_called() + + def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance): + """Test handling of Redis set errors in debug mode.""" + cache_embedding = CacheEmbedding(mock_model_instance) + query = "test query" + + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + 
latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="text-embedding-ada-002", + embeddings=[normalized], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_text_embedding.return_value = embedding_result + mock_redis.setex.side_effect = RuntimeError("Redis Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = True + + with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with pytest.raises(RuntimeError): + cache_embedding.embed_query(query) + + mock_logger.exception.assert_called() + + +class TestCacheEmbeddingInitialization: + """Test suite for CacheEmbedding initialization.""" + + def test_initialization_with_user(self): + """Test CacheEmbedding initialization with user parameter.""" + model_instance = Mock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + + cache_embedding = CacheEmbedding(model_instance, user="test-user") + + assert cache_embedding._model_instance == model_instance + assert cache_embedding._user == "test-user" + + def test_initialization_without_user(self): + """Test CacheEmbedding initialization without user parameter.""" + model_instance = Mock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + + cache_embedding = CacheEmbedding(model_instance) + + assert cache_embedding._model_instance == model_instance + assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py new file mode 100644 index 0000000000..033933e886 --- /dev/null +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py @@ -0,0 +1,220 @@ +"""Unit tests for embedding_base.py - the abstract Embeddings base class.""" + +import asyncio +import inspect +from typing import Any + 
+import pytest + +from core.rag.embedding.embedding_base import Embeddings + + +class ConcreteEmbeddings(Embeddings): + """Concrete implementation of Embeddings for testing.""" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[1.0] * 10 for _ in texts] + + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: + return [[1.0] * 10 for _ in multimodel_documents] + + def embed_query(self, text: str) -> list[float]: + return [1.0] * 10 + + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: + return [1.0] * 10 + + +class TestEmbeddingsBase: + """Test suite for the abstract Embeddings base class.""" + + def test_embeddings_is_abc(self): + """Test that Embeddings is an abstract base class.""" + assert hasattr(Embeddings, "__abstractmethods__") + assert len(Embeddings.__abstractmethods__) > 0 + + def test_embed_documents_is_abstract(self): + """Test that embed_documents is an abstract method.""" + assert "embed_documents" in Embeddings.__abstractmethods__ + + def test_embed_multimodal_documents_is_abstract(self): + """Test that embed_multimodal_documents is an abstract method.""" + assert "embed_multimodal_documents" in Embeddings.__abstractmethods__ + + def test_embed_query_is_abstract(self): + """Test that embed_query is an abstract method.""" + assert "embed_query" in Embeddings.__abstractmethods__ + + def test_embed_multimodal_query_is_abstract(self): + """Test that embed_multimodal_query is an abstract method.""" + assert "embed_multimodal_query" in Embeddings.__abstractmethods__ + + def test_embed_documents_raises_not_implemented(self): + """Test that embed_documents raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_documents) + assert "raise NotImplementedError" in source + + def test_embed_multimodal_documents_raises_not_implemented(self): + """Test that embed_multimodal_documents raises NotImplementedError in its 
body.""" + source = inspect.getsource(Embeddings.embed_multimodal_documents) + assert "raise NotImplementedError" in source + + def test_embed_query_raises_not_implemented(self): + """Test that embed_query raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_query) + assert "raise NotImplementedError" in source + + def test_embed_multimodal_query_raises_not_implemented(self): + """Test that embed_multimodal_query raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_multimodal_query) + assert "raise NotImplementedError" in source + + def test_aembed_documents_raises_not_implemented(self): + """Test that aembed_documents raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.aembed_documents) + assert "raise NotImplementedError" in source + + def test_aembed_query_raises_not_implemented(self): + """Test that aembed_query raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.aembed_query) + assert "raise NotImplementedError" in source + + def test_concrete_implementation_works(self): + """Test that a concrete implementation of Embeddings works correctly.""" + concrete = ConcreteEmbeddings() + result = concrete.embed_documents(["test1", "test2"]) + assert len(result) == 2 + assert all(len(emb) == 10 for emb in result) + + def test_concrete_implementation_embed_query(self): + """Test concrete implementation of embed_query.""" + concrete = ConcreteEmbeddings() + result = concrete.embed_query("test query") + assert len(result) == 10 + + def test_concrete_implementation_embed_multimodal_documents(self): + """Test concrete implementation of embed_multimodal_documents.""" + concrete = ConcreteEmbeddings() + docs: list[dict[str, Any]] = [{"file_id": "file1"}, {"file_id": "file2"}] + result = concrete.embed_multimodal_documents(docs) + assert len(result) == 2 + + def test_concrete_implementation_embed_multimodal_query(self): + """Test concrete 
implementation of embed_multimodal_query.""" + concrete = ConcreteEmbeddings() + result = concrete.embed_multimodal_query({"file_id": "test"}) + assert len(result) == 10 + + +class TestEmbeddingsNotImplemented: + """Test that abstract methods raise NotImplementedError when called.""" + + def test_embed_query_raises_not_implemented(self): + """Test that embed_query raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_query("test") + + def test_embed_documents_raises_not_implemented(self): + """Test that embed_documents raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_documents(["test"]) + + def test_embed_multimodal_documents_raises_not_implemented(self): + 
"""Test that embed_multimodal_documents raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_multimodal_documents([{"file_id": "test"}]) + + def test_embed_multimodal_query_raises_not_implemented(self): + """Test that embed_multimodal_query raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_multimodal_query({"file_id": "test"}) + + def test_aembed_documents_raises_not_implemented(self): + """Test that aembed_documents raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + 
PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + + async def run_test(): + with pytest.raises(NotImplementedError): + await partial.aembed_documents(["test"]) + + asyncio.run(run_test()) + + def test_aembed_query_raises_not_implemented(self): + """Test that aembed_query raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + + async def run_test(): + with pytest.raises(NotImplementedError): + await partial.aembed_query("test") + + asyncio.run(run_test()) diff --git a/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py b/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py new file mode 100644 index 0000000000..eb14622d7a --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py @@ -0,0 +1,85 @@ +from io import BytesIO + +import pytest + +from core.rag.extractor.blob.blob import Blob + + +class TestBlob: + def test_requires_data_or_path(self): + with pytest.raises(ValueError, match="Either data or path must be provided"): + Blob() + + def 
test_source_property_and_repr_include_path(self, tmp_path): + file_path = tmp_path / "sample.txt" + file_path.write_text("hello", encoding="utf-8") + + blob = Blob.from_path(str(file_path)) + + assert blob.source == str(file_path) + assert str(file_path) in repr(blob) + + def test_as_string_from_bytes_and_str(self): + assert Blob.from_data(b"abc").as_string() == "abc" + assert Blob.from_data("plain-text").as_string() == "plain-text" + + def test_as_string_from_path(self, tmp_path): + file_path = tmp_path / "sample.txt" + file_path.write_text("from-file", encoding="utf-8") + + blob = Blob.from_path(str(file_path)) + + assert blob.as_string() == "from-file" + + def test_as_string_raises_for_invalid_state(self): + blob = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8") + + with pytest.raises(ValueError, match="Unable to get string for blob"): + blob.as_string() + + def test_as_bytes_from_bytes_str_and_path(self, tmp_path): + from_bytes = Blob.from_data(b"abc") + from_str = Blob.from_data("abc", encoding="utf-8") + + file_path = tmp_path / "sample.bin" + file_path.write_bytes(b"from-path") + from_path = Blob.from_path(str(file_path)) + + assert from_bytes.as_bytes() == b"abc" + assert from_str.as_bytes() == b"abc" + assert from_path.as_bytes() == b"from-path" + + def test_as_bytes_raises_for_invalid_state(self): + blob = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8") + + with pytest.raises(ValueError, match="Unable to get bytes for blob"): + blob.as_bytes() + + def test_as_bytes_io_for_bytes_and_path(self, tmp_path): + data_blob = Blob.from_data(b"bytes-io") + with data_blob.as_bytes_io() as stream: + assert isinstance(stream, BytesIO) + assert stream.read() == b"bytes-io" + + file_path = tmp_path / "stream.bin" + file_path.write_bytes(b"path-stream") + path_blob = Blob.from_path(str(file_path)) + with path_blob.as_bytes_io() as stream: + assert stream.read() == b"path-stream" + + def 
test_as_bytes_io_raises_for_unsupported_data_type(self): + blob = Blob.from_data("text-value") + + with pytest.raises(NotImplementedError, match="Unable to convert blob"): + with blob.as_bytes_io(): + pass + + def test_from_path_respects_guessing_and_explicit_mime(self, tmp_path): + file_path = tmp_path / "example.txt" + file_path.write_text("x", encoding="utf-8") + + guessed = Blob.from_path(str(file_path)) + explicit = Blob.from_path(str(file_path), mime_type="custom/type", guess_type=False) + + assert guessed.mimetype == "text/plain" + assert explicit.mimetype == "custom/type" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 4ee04ddebc..d3040395be 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -1,61 +1,337 @@ -import os +"""Unit tests for Firecrawl app and extractor integration points.""" + +import json +from collections.abc import Mapping +from typing import Any from unittest.mock import MagicMock import pytest from pytest_mock import MockerFixture +import core.rag.extractor.firecrawl.firecrawl_app as firecrawl_module from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp -from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response +from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor -def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture): - url = "https://firecrawl.dev" - api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" - base_url = "https://api.firecrawl.dev" - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) - params = { - "includePaths": [], - "excludePaths": [], - "maxDepth": 1, - "limit": 1, - } - mocked_firecrawl = { - "id": "test", - } - mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl)) - job_id = 
firecrawl_app.crawl_url(url, params) - - assert job_id is not None - assert isinstance(job_id, str) +def _response(status_code: int, json_data: Mapping[str, Any] | None = None, text: str = "") -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.text = text + response.json.return_value = json_data if json_data is not None else {} + return response -def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture): - api_key = "fc-" - base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"] - for base in base_urls: - app = FirecrawlApp(api_key=api_key, base_url=base) - mock_post = mocker.patch("httpx.post") - mock_resp = MagicMock() - mock_resp.status_code = 200 - mock_resp.json.return_value = {"id": "job123"} - mock_post.return_value = mock_resp - app.crawl_url("https://example.com", params=None) - called_url = mock_post.call_args[0][0] - assert called_url == "https://custom.firecrawl.dev/v2/crawl" +class TestFirecrawlApp: + def test_init_requires_api_key_for_default_base_url(self): + with pytest.raises(ValueError, match="No API key provided"): + FirecrawlApp(api_key=None, base_url="https://api.firecrawl.dev") + + def test_prepare_headers_and_build_url(self): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev/") + + assert app._prepare_headers() == { + "Content-Type": "application/json", + "Authorization": "Bearer fc-key", + } + assert app._build_url("/v2/crawl") == "https://custom.firecrawl.dev/v2/crawl" + + def test_scrape_url_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch( + "httpx.post", + return_value=_response( + 200, + { + "data": { + "metadata": { + "title": "t", + "description": "d", + "sourceURL": "https://example.com", + }, + "markdown": "body", + } + }, + ), + ) + + result = app.scrape_url("https://example.com", params={"onlyMainContent": False}) + + assert result == { + "title": 
"t", + "description": "d", + "source_url": "https://example.com", + "markdown": "body", + } + + def test_scrape_url_handles_known_error_status(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("boom")) + mocker.patch("httpx.post", return_value=_response(429, {"error": "limit"})) + + with pytest.raises(Exception, match="boom"): + app.scrape_url("https://example.com") + + mock_handle.assert_called_once() + + def test_scrape_url_unknown_status_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(404, text="Not Found")) + + with pytest.raises(Exception, match="Failed to scrape URL. Status code: 404"): + app.scrape_url("https://example.com") + + def test_crawl_url_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"id": "job-1"})) + + assert app.crawl_url("https://example.com") == "job-1" + + def test_crawl_url_non_200_uses_error_handler(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl failed")) + mocker.patch("httpx.post", return_value=_response(500, {"error": "server"})) + + with pytest.raises(Exception, match="crawl failed"): + app.crawl_url("https://example.com") + + mock_handle.assert_called_once() + + def test_map_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"success": True, "links": ["a", "b"]})) + + assert app.map("https://example.com") == {"success": True, "links": ["a", "b"]} + + def test_map_known_error(self, 
mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error") + mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"})) + + assert app.map("https://example.com") == {} + mock_handle.assert_called_once() + + def test_map_unknown_error_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(418, text="teapot")) + + with pytest.raises(Exception, match="Failed to start map job. Status code: 418"): + app.map("https://example.com") + + def test_check_crawl_status_completed_with_data(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + payload = { + "status": "completed", + "total": 2, + "completed": 2, + "data": [ + { + "metadata": {"title": "a", "description": "desc-a", "sourceURL": "https://a"}, + "markdown": "m-a", + }, + { + "metadata": {"title": "b", "description": "desc-b", "sourceURL": "https://b"}, + "markdown": "m-b", + }, + {"metadata": {"title": "skip"}}, + ], + } + mocker.patch("httpx.get", return_value=_response(200, payload)) + + save_calls: list[tuple[str, bytes]] = [] + delete_calls: list[str] = [] + + mock_storage = MagicMock() + mock_storage.exists.return_value = True + mock_storage.delete.side_effect = lambda key: delete_calls.append(key) + mock_storage.save.side_effect = lambda key, data: save_calls.append((key, data)) + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 2 + assert result["current"] == 2 + assert len(result["data"]) == 2 + assert delete_calls == ["website_files/job-42.txt"] + assert len(save_calls) == 1 + assert save_calls[0][0] == "website_files/job-42.txt" + + def 
test_check_crawl_status_completed_with_zero_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": 0, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + payload = {"status": "processing", "total": 5, "completed": 1, "data": []} + mocker.patch("httpx.get", return_value=_response(200, payload)) + + assert app.check_crawl_status("job-1") == { + "status": "processing", + "total": 5, + "current": 1, + "data": [], + } + + def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error") + mocker.patch("httpx.get", return_value=_response(500, {"error": "server"})) + + assert app.check_crawl_status("job-1") == {} + mock_handle.assert_called_once() + + def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + payload = { + "status": "completed", + "total": 1, + "completed": 1, + "data": [{"metadata": {"title": "a", "sourceURL": "https://a"}, "markdown": "m-a"}], + } + mocker.patch("httpx.get", return_value=_response(200, payload)) + + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mock_storage.save.side_effect = RuntimeError("save failed") + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + with pytest.raises(Exception, match="Error saving crawl data"): + app.check_crawl_status("job-err") + + def test_extract_common_fields_and_status_formatter(self): + app = FirecrawlApp(api_key="fc-key", 
base_url="https://custom.firecrawl.dev") + + fields = app._extract_common_fields( + {"metadata": {"title": "t", "description": "d", "sourceURL": "u"}, "markdown": "m"} + ) + assert fields == {"title": "t", "description": "d", "source_url": "u", "markdown": "m"} + + status = app._format_crawl_status_response("completed", {"total": 1, "completed": 1}, [fields]) + assert status == {"status": "completed", "total": 1, "current": 1, "data": [fields]} + + def test_post_and_get_request_retry_logic(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep") + + resp_502_a = _response(502) + resp_502_b = _response(502) + resp_200 = _response(200) + + mocker.patch("httpx.post", side_effect=[resp_502_a, resp_200]) + post_result = app._post_request("u", {"x": 1}, {"h": 1}, retries=3, backoff_factor=0.5) + assert post_result is resp_200 + + mocker.patch("httpx.get", side_effect=[resp_502_b, _response(200)]) + get_result = app._get_request("u", {"h": 1}, retries=3, backoff_factor=0.25) + assert get_result.status_code == 200 + + assert sleep_mock.call_count == 2 + + def test_post_and_get_request_return_last_502(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep") + + last_post = _response(502) + mocker.patch("httpx.post", side_effect=[_response(502), last_post]) + assert app._post_request("u", {}, {}, retries=2).status_code == 502 + + last_get = _response(502) + mocker.patch("httpx.get", side_effect=[_response(502), last_get]) + assert app._get_request("u", {}, retries=2).status_code == 502 + + assert sleep_mock.call_count == 4 + + def test_handle_error_with_json_and_plain_text(self): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + + json_error = _response(400, {"message": "bad request"}) + with 
pytest.raises(Exception, match="bad request"): + app._handle_error(json_error, "run task") + + non_json = MagicMock() + non_json.status_code = 400 + non_json.text = "plain error" + non_json.json.side_effect = json.JSONDecodeError("bad", "x", 0) + + with pytest.raises(Exception, match="plain error"): + app._handle_error(non_json, "run task") + + def test_search_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"success": True, "data": [{"url": "x"}]})) + assert app.search("python")["success"] is True + + def test_search_warning_failure(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"success": False, "warning": "bad search"})) + with pytest.raises(Exception, match="bad search"): + app.search("python") + + def test_search_known_http_error(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error") + mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"})) + assert app.search("python") == {} + mock_handle.assert_called_once() + + def test_search_unknown_http_error(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(418, text="teapot")) + with pytest.raises(Exception, match="Failed to perform search. 
Status code: 418"): + app.search("python") -def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture): - api_key = "fc-" - app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/") - mock_post = mocker.patch("httpx.post") - mock_resp = MagicMock() - mock_resp.status_code = 404 - mock_resp.text = "Not Found" - mock_resp.json.side_effect = Exception("Not JSON") - mock_post.return_value = mock_resp +class TestFirecrawlWebExtractor: + def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data", + return_value={ + "markdown": "crawl content", + "source_url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) - with pytest.raises(Exception) as excinfo: - app.scrape_url("https://example.com") + extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + docs = extractor.extract() - # Should not raise a JSONDecodeError; current behavior reports status code only - assert str(excinfo.value) == "Failed to scrape URL. 
Status code: 404" + assert len(docs) == 1 + assert docs[0].page_content == "crawl content" + assert docs[0].metadata["source_url"] == "https://example.com" + + def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data", + return_value=None, + ) + + extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + assert extractor.extract() == [] + + def test_extract_scrape_mode_returns_document(self, mocker: MockerFixture): + mock_scrape = mocker.patch( + "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_scrape_url_data", + return_value={ + "markdown": "scrape content", + "source_url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) + + extractor = FirecrawlWebExtractor( + "https://example.com", "job-1", "tenant-1", mode="scrape", only_main_content=False + ) + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "scrape content" + mock_scrape.assert_called_once_with("firecrawl", "https://example.com", "tenant-1", False) + + def test_extract_unknown_mode_returns_empty(self): + extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="unknown") + assert extractor.extract() == [] diff --git a/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py new file mode 100644 index 0000000000..e6a06f163e --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py @@ -0,0 +1,95 @@ +import csv +import io +from types import SimpleNamespace + +import pandas as pd +import pytest + +import core.rag.extractor.csv_extractor as csv_module +from core.rag.extractor.csv_extractor import CSVExtractor + + +class _ManagedStringIO(io.StringIO): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + 
self.close() + return False + + +class TestCSVExtractor: + def test_extract_success_with_source_column(self, tmp_path): + file_path = tmp_path / "data.csv" + file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8") + + extractor = CSVExtractor(str(file_path), source_column="id") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "id: source-1;body: hello" + assert docs[0].metadata == {"source": "source-1", "row": 0} + + def test_extract_raises_when_source_column_missing(self, tmp_path): + file_path = tmp_path / "data.csv" + file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8") + + extractor = CSVExtractor(str(file_path), source_column="missing_col") + + with pytest.raises(ValueError, match="Source column 'missing_col' not found"): + extractor.extract() + + def test_extract_wraps_unicode_error_when_autodetect_disabled(self, monkeypatch): + extractor = CSVExtractor("dummy.csv", autodetect_encoding=False) + + def raise_decode(*args, **kwargs): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + + monkeypatch.setattr("builtins.open", raise_decode) + + with pytest.raises(RuntimeError, match="Error loading dummy.csv"): + extractor.extract() + + def test_extract_autodetect_encoding_success(self, monkeypatch): + extractor = CSVExtractor("dummy.csv", autodetect_encoding=True) + attempted_encodings: list[str | None] = [] + + def fake_open(path, newline="", encoding=None): + attempted_encodings.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + if encoding == "bad": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + return _ManagedStringIO("id,body\nsource-1,hello\n") + + monkeypatch.setattr("builtins.open", fake_open) + monkeypatch.setattr( + csv_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")], + ) + + docs = extractor.extract() + + assert len(docs) == 1 + assert 
docs[0].page_content == "id: source-1;body: hello" + assert attempted_encodings == [None, "bad", "utf-8"] + + def test_extract_autodetect_encoding_all_attempts_fail_returns_empty(self, monkeypatch): + extractor = CSVExtractor("dummy.csv", autodetect_encoding=True) + + def always_raise(*args, **kwargs): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + + monkeypatch.setattr("builtins.open", always_raise) + monkeypatch.setattr(csv_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="bad")]) + + assert extractor.extract() == [] + + def test_read_from_file_re_raises_csv_error(self, monkeypatch): + extractor = CSVExtractor("dummy.csv") + + monkeypatch.setattr(pd, "read_csv", lambda *args, **kwargs: (_ for _ in ()).throw(csv.Error("bad csv"))) + + with pytest.raises(csv.Error, match="bad csv"): + extractor._read_from_file(io.StringIO("x")) diff --git a/api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py new file mode 100644 index 0000000000..d2bcc1e2c4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py @@ -0,0 +1,117 @@ +from types import SimpleNamespace + +import pandas as pd +import pytest + +import core.rag.extractor.excel_extractor as excel_module +from core.rag.extractor.excel_extractor import ExcelExtractor + + +class _FakeCell: + def __init__(self, value, hyperlink=None): + self.value = value + self.hyperlink = hyperlink + + +class _FakeSheet: + def __init__(self, header_rows, data_rows): + self._header_rows = header_rows + self._data_rows = data_rows + + def iter_rows(self, min_row=1, max_row=None, max_col=None, values_only=False): + if values_only: + for row in self._header_rows: + yield tuple(row) + return + + for row in self._data_rows: + if max_col is not None: + yield tuple(row[:max_col]) + else: + yield tuple(row) + + +class _FakeWorkbook: + def __init__(self, sheets): + self._sheets = sheets + self.sheetnames = 
list(sheets.keys()) + self.closed = False + + def __getitem__(self, key): + return self._sheets[key] + + def close(self): + self.closed = True + + +class TestExcelExtractor: + def test_extract_xlsx_with_hyperlinks_and_sheet_skip(self, monkeypatch): + sheet_with_data = _FakeSheet( + header_rows=[("Name", "Link")], + data_rows=[ + (_FakeCell("Alice"), _FakeCell("Doc", hyperlink=SimpleNamespace(target="https://example.com/doc"))), + (_FakeCell(None), _FakeCell(123)), + (_FakeCell(None), _FakeCell(None)), + ], + ) + empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[]) + + workbook = _FakeWorkbook({"Data": sheet_with_data, "Empty": empty_sheet}) + monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbook) + + extractor = ExcelExtractor("/tmp/sample.xlsx") + docs = extractor.extract() + + assert workbook.closed is True + assert len(docs) == 2 + assert docs[0].page_content == '"Name":"Alice";"Link":"[Doc](https://example.com/doc)"' + assert docs[1].page_content == '"Name":"";"Link":"123"' + assert all(doc.metadata["source"] == "/tmp/sample.xlsx" for doc in docs) + + def test_extract_xls_path(self, monkeypatch): + class FakeExcelFile: + sheet_names = ["Sheet1"] + + def parse(self, sheet_name): + assert sheet_name == "Sheet1" + return pd.DataFrame([{"A": "x", "B": 1}, {"A": None, "B": None}]) + + monkeypatch.setattr(pd, "ExcelFile", lambda path, engine=None: FakeExcelFile()) + + extractor = ExcelExtractor("/tmp/sample.xls") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == '"A":"x";"B":"1.0"' + assert docs[0].metadata == {"source": "/tmp/sample.xls"} + + def test_extract_unsupported_extension_raises(self): + extractor = ExcelExtractor("/tmp/sample.txt") + + with pytest.raises(ValueError, match="Unsupported file extension"): + extractor.extract() + + def test_find_header_and_columns_prefers_first_row_with_two_columns(self): + sheet = _FakeSheet( + header_rows=[(None, None, None), ("A", "B", None), 
("X", None, None)], + data_rows=[], + ) + extractor = ExcelExtractor("dummy.xlsx") + + header_row_idx, column_map, max_col_idx = extractor._find_header_and_columns(sheet) + + assert header_row_idx == 2 + assert column_map == {0: "A", 1: "B"} + assert max_col_idx == 2 + + def test_find_header_and_columns_fallback_and_empty_case(self): + extractor = ExcelExtractor("dummy.xlsx") + + fallback_sheet = _FakeSheet(header_rows=[("Only", None), (None, "Second")], data_rows=[]) + row_idx, column_map, max_col_idx = extractor._find_header_and_columns(fallback_sheet) + assert row_idx == 1 + assert column_map == {0: "Only"} + assert max_col_idx == 1 + + empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[]) + assert extractor._find_header_and_columns(empty_sheet) == (0, {}, 0) diff --git a/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py b/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py new file mode 100644 index 0000000000..5beed88971 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py @@ -0,0 +1,272 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.extract_processor as processor_module +from core.rag.extractor.entity.datasource_type import DatasourceType +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.models.document import Document + + +class _ExtractorFactory: + def __init__(self) -> None: + self.calls = [] + + def make(self, name: str) -> type[object]: + calls = self.calls + + class DummyExtractor: + def __init__(self, *args, **kwargs): + calls.append((name, args, kwargs)) + + def extract(self): + return [Document(page_content=f"extracted-by-{name}")] + + return DummyExtractor + + +def _patch_all_extractors(monkeypatch) -> _ExtractorFactory: + factory = _ExtractorFactory() + + for cls_name in [ + "CSVExtractor", + "ExcelExtractor", + "FirecrawlWebExtractor", + "HtmlExtractor", + 
"JinaReaderWebExtractor", + "MarkdownExtractor", + "NotionExtractor", + "PdfExtractor", + "TextExtractor", + "UnstructuredEmailExtractor", + "UnstructuredEpubExtractor", + "UnstructuredMarkdownExtractor", + "UnstructuredMsgExtractor", + "UnstructuredPPTExtractor", + "UnstructuredPPTXExtractor", + "UnstructuredWordExtractor", + "UnstructuredXmlExtractor", + "WaterCrawlWebExtractor", + "WordExtractor", + ]: + monkeypatch.setattr(processor_module, cls_name, factory.make(cls_name)) + + return factory + + +class TestExtractProcessorLoaders: + def test_load_from_upload_file_return_docs_and_text(self, monkeypatch): + monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs)) + + monkeypatch.setattr( + ExtractProcessor, + "extract", + lambda extract_setting, is_automatic=False, file_path=None: [ + Document(page_content="doc-1"), + Document(page_content="doc-2"), + ], + ) + + upload_file = SimpleNamespace(key="file.txt") + + docs = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=False) + text = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=True) + + assert len(docs) == 2 + assert text == "doc-1\ndoc-2" + + @pytest.mark.parametrize( + ("url", "headers", "expected_suffix"), + [ + ("https://example.com/file.txt", {"Content-Type": "text/plain"}, ".txt"), + ("https://example.com/no_suffix", {"Content-Type": "application/pdf"}, ".pdf"), + ( + "https://example.com/no_suffix", + {"Content-Disposition": 'attachment; filename="report.md"'}, + ".md", + ), + ( + "https://example.com/no_suffix", + {"Content-Disposition": 'attachment; filename="report"'}, + "", + ), + ], + ) + def test_load_from_url_builds_temp_file_with_correct_suffix(self, monkeypatch, url, headers, expected_suffix): + response = SimpleNamespace(headers=headers, content=b"body") + monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response) + monkeypatch.setattr(processor_module, 
"ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs)) + + captured = {} + + def fake_extract(extract_setting, is_automatic=False, file_path=None): + key = "file_path_docs" if "file_path_docs" not in captured else "file_path_text" + captured[key] = file_path + return [Document(page_content="u1"), Document(page_content="u2")] + + monkeypatch.setattr(ExtractProcessor, "extract", fake_extract) + + docs = ExtractProcessor.load_from_url(url, return_text=False) + assert captured["file_path_docs"].endswith(expected_suffix) + + text = ExtractProcessor.load_from_url(url, return_text=True) + assert captured["file_path_text"].endswith(expected_suffix) + + assert len(docs) == 2 + assert text == "u1\nu2" + + +class TestExtractProcessorFileRouting: + @pytest.fixture(autouse=True) + def _set_unstructured_config(self, monkeypatch): + monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_URL", "https://unstructured") + monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_KEY", "key") + + def _run_extract_for_extension(self, monkeypatch, extension: str, etl_type: str, is_automatic: bool = False): + factory = _patch_all_extractors(monkeypatch) + monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type) + + def fake_download(key: str, local_path: str): + Path(local_path).write_text("content", encoding="utf-8") + + monkeypatch.setattr(processor_module.storage, "download", fake_download) + monkeypatch.setattr(processor_module.tempfile, "_get_candidate_names", lambda: iter(["candidate-name"])) + + setting = SimpleNamespace( + datasource_type=DatasourceType.FILE, + upload_file=SimpleNamespace(key=f"uploaded{extension}", tenant_id="tenant-1", created_by="user-1"), + ) + + docs = ExtractProcessor.extract(setting, is_automatic=is_automatic) + + assert len(docs) == 1 + assert docs[0].page_content.startswith("extracted-by-") + return factory.calls[-1][0], factory.calls[-1][1], factory.calls[-1][2] + + @pytest.mark.parametrize( + ("extension", 
"expected_extractor", "is_automatic"), + [ + (".xlsx", "ExcelExtractor", False), + (".xls", "ExcelExtractor", False), + (".pdf", "PdfExtractor", False), + (".md", "UnstructuredMarkdownExtractor", True), + (".mdx", "MarkdownExtractor", False), + (".htm", "HtmlExtractor", False), + (".html", "HtmlExtractor", False), + (".docx", "WordExtractor", False), + (".doc", "UnstructuredWordExtractor", False), + (".csv", "CSVExtractor", False), + (".msg", "UnstructuredMsgExtractor", False), + (".eml", "UnstructuredEmailExtractor", False), + (".ppt", "UnstructuredPPTExtractor", False), + (".pptx", "UnstructuredPPTXExtractor", False), + (".xml", "UnstructuredXmlExtractor", False), + (".epub", "UnstructuredEpubExtractor", False), + (".txt", "TextExtractor", False), + ], + ) + def test_extract_routes_file_extensions_for_unstructured_mode( + self, monkeypatch, extension, expected_extractor, is_automatic + ): + extractor_name, args, kwargs = self._run_extract_for_extension( + monkeypatch, extension, etl_type="Unstructured", is_automatic=is_automatic + ) + + assert extractor_name == expected_extractor + assert args + + @pytest.mark.parametrize( + ("extension", "expected_extractor"), + [ + (".xlsx", "ExcelExtractor"), + (".pdf", "PdfExtractor"), + (".markdown", "MarkdownExtractor"), + (".html", "HtmlExtractor"), + (".docx", "WordExtractor"), + (".csv", "CSVExtractor"), + (".epub", "UnstructuredEpubExtractor"), + (".txt", "TextExtractor"), + ], + ) + def test_extract_routes_file_extensions_for_default_mode(self, monkeypatch, extension, expected_extractor): + extractor_name, _, _ = self._run_extract_for_extension(monkeypatch, extension, etl_type="SelfHosted") + + assert extractor_name == expected_extractor + + def test_extract_requires_upload_file_when_file_path_not_provided(self): + setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None) + + with pytest.raises(AssertionError, match="upload_file is required"): + ExtractProcessor.extract(setting) + + +class 
TestExtractProcessorDatasourceRouting: + def test_extract_routes_notion_datasource(self, monkeypatch): + factory = _patch_all_extractors(monkeypatch) + + notion_info = SimpleNamespace( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + document="doc", + tenant_id="tenant", + credential_id="cred", + ) + setting = SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=notion_info) + + docs = ExtractProcessor.extract(setting) + + assert docs[0].page_content == "extracted-by-NotionExtractor" + assert factory.calls[-1][0] == "NotionExtractor" + + @pytest.mark.parametrize( + ("provider", "expected"), + [ + ("firecrawl", "FirecrawlWebExtractor"), + ("watercrawl", "WaterCrawlWebExtractor"), + ("jinareader", "JinaReaderWebExtractor"), + ], + ) + def test_extract_routes_website_datasource_providers(self, monkeypatch, provider: str, expected: str): + factory = _patch_all_extractors(monkeypatch) + + website_info = SimpleNamespace( + provider=provider, + url="https://example.com", + job_id="job", + tenant_id="tenant", + mode="crawl", + only_main_content=True, + ) + setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=website_info) + + docs = ExtractProcessor.extract(setting) + assert docs[0].page_content == f"extracted-by-{expected}" + assert factory.calls[-1][0] == expected + + def test_extract_unsupported_website_provider(self): + bad_provider = SimpleNamespace( + provider="unknown", + url="https://example.com", + job_id="job", + tenant_id="tenant", + mode="crawl", + only_main_content=True, + ) + setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=bad_provider) + + with pytest.raises(ValueError, match="Unsupported website provider"): + ExtractProcessor.extract(setting) + + def test_extract_unsupported_datasource_type(self): + with pytest.raises(ValueError, match="Unsupported datasource type"): + ExtractProcessor.extract(SimpleNamespace(datasource_type="unknown")) + + def 
test_extract_requires_notion_info(self): + with pytest.raises(AssertionError, match="notion_info is required"): + ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=None)) + + def test_extract_requires_website_info(self): + with pytest.raises(AssertionError, match="website_info is required"): + ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=None)) diff --git a/api/tests/unit_tests/core/rag/extractor/test_extractor_base.py b/api/tests/unit_tests/core/rag/extractor/test_extractor_base.py new file mode 100644 index 0000000000..1d5f27181b --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_extractor_base.py @@ -0,0 +1,26 @@ +import pytest + +from core.rag.extractor.extractor_base import BaseExtractor + + +class _CallsBaseExtractor(BaseExtractor): + def extract(self): + return super().extract() + + +class _ConcreteExtractor(BaseExtractor): + def extract(self): + return ["ok"] + + +class TestBaseExtractor: + def test_extract_default_raises_not_implemented(self): + extractor = _CallsBaseExtractor() + + with pytest.raises(NotImplementedError): + extractor.extract() + + def test_concrete_extractor_can_override(self): + extractor = _ConcreteExtractor() + + assert extractor.extract() == ["ok"] diff --git a/api/tests/unit_tests/core/rag/extractor/test_helpers.py b/api/tests/unit_tests/core/rag/extractor/test_helpers.py index edf8735e57..74387f749d 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_helpers.py +++ b/api/tests/unit_tests/core/rag/extractor/test_helpers.py @@ -1,10 +1,55 @@ import tempfile +from types import SimpleNamespace -from core.rag.extractor.helpers import FileEncoding, detect_file_encodings +import pytest + +from core.rag.extractor import helpers +from core.rag.extractor.helpers import detect_file_encodings -def test_detect_file_encodings() -> None: - with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp: - temp.write("Shared data") - 
temp_path = temp.name - assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")] +class TestHelpers: + def test_detect_file_encodings(self) -> None: + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp: + temp.write("Shared data") + temp.flush() + temp_path = temp.name + encodings = detect_file_encodings(temp_path) + + assert len(encodings) == 1 + assert encodings[0].encoding in {"utf_8", "ascii"} + assert encodings[0].confidence == 0.0 + # Assert the language field for full coverage + assert encodings[0].language is not None + + def test_detect_file_encodings_timeout(self, monkeypatch): + class FakeFuture: + def result(self, timeout=None): + raise helpers.concurrent.futures.TimeoutError() + + class FakeExecutor: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def submit(self, fn, file_path): + return FakeFuture() + + monkeypatch.setattr(helpers.concurrent.futures, "ThreadPoolExecutor", lambda: FakeExecutor()) + + with pytest.raises(TimeoutError, match="Timeout reached while detecting encoding"): + detect_file_encodings("file.txt", timeout=1) + + def test_detect_file_encodings_raises_when_encoding_not_detected(self, monkeypatch): + class FakeResult: + encoding = None + coherence = 0.0 + language = None + + monkeypatch.setattr( + helpers.charset_normalizer, "from_path", lambda _: SimpleNamespace(best=lambda: FakeResult()) + ) + + with pytest.raises(RuntimeError, match="Could not detect encoding"): + detect_file_encodings("file.txt") diff --git a/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py new file mode 100644 index 0000000000..8bc65e5654 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py @@ -0,0 +1,21 @@ +from core.rag.extractor.html_extractor import HtmlExtractor + + +class TestHtmlExtractor: + def 
test_extract_returns_text_content(self, tmp_path): + file_path = tmp_path / "sample.html" + file_path.write_text("

Title

Hello

", encoding="utf-8") + + extractor = HtmlExtractor(str(file_path)) + docs = extractor.extract() + + assert len(docs) == 1 + assert "".join(docs[0].page_content.split()) == "TitleHello" + + def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path): + file_path = tmp_path / "sample.html" + file_path.write_text(" \n ", encoding="utf-8") + + extractor = HtmlExtractor(str(file_path)) + + assert extractor._load_as_text() == "" diff --git a/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py new file mode 100644 index 0000000000..0b4c9bd809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py @@ -0,0 +1,47 @@ +from pytest_mock import MockerFixture + +from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor + + +class TestJinaReaderWebExtractor: + def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value={ + "content": "markdown-content", + "url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) + + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "markdown-content" + assert docs[0].metadata == { + "source_url": "https://example.com", + "description": "desc", + "title": "title", + } + + def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value=None, + ) + + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_non_crawl_mode_returns_empty(self, mocker: MockerFixture): + mock_get_crawl = 
mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value={"content": "unused"}, + ) + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert extractor.extract() == [] + mock_get_crawl.assert_not_called() diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py index d4cf534c56..7e78c86c7d 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py @@ -1,8 +1,15 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.markdown_extractor as markdown_module from core.rag.extractor.markdown_extractor import MarkdownExtractor -def test_markdown_to_tups(): - markdown = """ +class TestMarkdownExtractor: + def test_markdown_to_tups(self): + markdown = """ this is some text without header # title 1 @@ -11,12 +18,113 @@ this is balabala text ## title 2 this is more specific text. 
""" - extractor = MarkdownExtractor(file_path="dummy_path") - updated_output = extractor.markdown_to_tups(markdown) - assert len(updated_output) == 3 - key, header_value = updated_output[0] - assert key == None - assert header_value.strip() == "this is some text without header" - title_1, value = updated_output[1] - assert title_1.strip() == "title 1" - assert value.strip() == "this is balabala text" + extractor = MarkdownExtractor(file_path="dummy_path") + updated_output = extractor.markdown_to_tups(markdown) + + assert len(updated_output) == 3 + key, header_value = updated_output[0] + assert key is None + assert header_value.strip() == "this is some text without header" + + title_1, value = updated_output[1] + assert title_1.strip() == "title 1" + assert value.strip() == "this is balabala text" + + def test_markdown_to_tups_keeps_code_block_headers_literal(self): + markdown = """# Header +before +```python +# this is not a heading +print('x') +``` +after +""" + extractor = MarkdownExtractor(file_path="dummy_path") + + tups = extractor.markdown_to_tups(markdown) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "# this is not a heading" in tups[1][1] + + def test_remove_images_and_hyperlinks(self): + extractor = MarkdownExtractor(file_path="dummy_path") + + with_images = "before ![[image.png]] after" + with_links = "[OpenAI](https://openai.com)" + + assert extractor.remove_images(with_images) == "before after" + assert extractor.remove_hyperlinks(with_links) == "OpenAI" + + def test_parse_tups_reads_file_and_applies_options(self, tmp_path): + markdown_file = tmp_path / "doc.md" + markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8") + + extractor = MarkdownExtractor( + file_path=str(markdown_file), + remove_hyperlinks=True, + remove_images=True, + autodetect_encoding=False, + ) + + tups = extractor.parse_tups(str(markdown_file)) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + 
assert "[link]" not in tups[1][1] + assert "img.png" not in tups[1][1] + + def test_parse_tups_autodetects_encoding_after_decode_error(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=True) + + calls: list[str | None] = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + if encoding == "bad-encoding": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + return "# H\ncontent" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + markdown_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad-encoding"), SimpleNamespace(encoding="utf-8")], + ) + + tups = extractor.parse_tups("dummy_path") + + assert len(tups) == 2 + assert calls == [None, "bad-encoding", "utf-8"] + + def test_parse_tups_decode_error_with_autodetect_disabled_raises(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=False) + + def raise_decode(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + + monkeypatch.setattr(Path, "read_text", raise_decode, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_parse_tups_other_exceptions_are_wrapped(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + + def raise_other(self, encoding=None): + raise OSError("disk error") + + monkeypatch.setattr(Path, "read_text", raise_other, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_extract_builds_documents_for_header_and_non_header(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + monkeypatch.setattr(extractor, "parse_tups", lambda _: [(None, "plain"), ("Header", "value")]) + + docs = extractor.extract() + + 
assert [doc.page_content for doc in docs] == ["plain", "\n\nHeader\nvalue"] diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index 58bec7d19e..6daee11f8f 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -1,93 +1,499 @@ +from types import SimpleNamespace from unittest import mock +import httpx +import pytest from pytest_mock import MockerFixture from core.rag.extractor import notion_extractor -user_id = "user1" -database_id = "database1" -page_id = "page1" - -extractor = notion_extractor.NotionExtractor( - notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" -) - - -def _generate_page(page_title: str): - return { - "object": "page", - "id": page_id, - "properties": { - "Page": { - "type": "title", - "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], - } - }, - } - - -def _generate_block(block_id: str, block_type: str, block_text: str): - return { - "object": "block", - "id": block_id, - "parent": {"type": "page_id", "page_id": page_id}, - "type": block_type, - "has_children": False, - block_type: { - "rich_text": [ - { - "type": "text", - "text": {"content": block_text}, - "plain_text": block_text, - } - ] - }, - } - - -def _mock_response(data): +def _mock_response(data, status_code: int = 200, text: str = ""): response = mock.Mock() - response.status_code = 200 + response.status_code = status_code + response.text = text response.json.return_value = data return response -def _remove_multiple_new_lines(text): - while "\n\n" in text: - text = text.replace("\n\n", "\n") - return text.strip() +class TestNotionExtractorInitAndPublicMethods: + def test_init_with_explicit_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + 
notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + assert extractor._notion_access_token == "token" + + def test_init_falls_back_to_env_token_when_credential_lookup_fails(self, monkeypatch): + monkeypatch.setattr( + notion_extractor.NotionExtractor, + "_get_access_token", + classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))), + ) + monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", "env-token", raising=False) + + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + assert extractor._notion_access_token == "env-token" + + def test_init_raises_if_no_credential_and_no_env_token(self, monkeypatch): + monkeypatch.setattr( + notion_extractor.NotionExtractor, + "_get_access_token", + classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))), + ) + monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", None, raising=False) + + with pytest.raises(ValueError, match="Must specify `integration_token`"): + notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + def test_extract_updates_last_edited_and_loads_documents(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + update_mock = mock.Mock() + load_mock = mock.Mock(return_value=[SimpleNamespace(page_content="doc")]) + monkeypatch.setattr(extractor, "update_last_edited_time", update_mock) + monkeypatch.setattr(extractor, "_load_data_as_documents", load_mock) + + docs = extractor.extract() + + update_mock.assert_called_once_with(None) + 
load_mock.assert_called_once_with("obj", "page") + assert len(docs) == 1 + + def test_load_data_as_documents_page_database_and_invalid(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + monkeypatch.setattr(extractor, "_get_notion_block_data", lambda _: ["line1", "line2"]) + page_docs = extractor._load_data_as_documents("page-id", "page") + assert page_docs[0].page_content == "line1\nline2" + + monkeypatch.setattr(extractor, "_get_notion_database_data", lambda _: [SimpleNamespace(page_content="db")]) + db_docs = extractor._load_data_as_documents("db-id", "database") + assert db_docs[0].page_content == "db" + + with pytest.raises(ValueError, match="notion page type not supported"): + extractor._load_data_as_documents("obj", "unsupported") -def test_notion_page(mocker: MockerFixture): - texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] - mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]), - ], - "next_cursor": None, - } - mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page)) +class TestNotionDatabase: + def test_get_notion_database_data_parses_property_types_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) - page_docs = extractor._load_data_as_documents(page_id, "page") - assert len(page_docs) == 1 - content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" + first_page = { + "results": [ + { + "properties": { + "tags": { + "type": "multi_select", + 
"multi_select": [{"name": "A"}, {"name": "B"}], + }, + "title_prop": {"type": "title", "title": [{"plain_text": "Title"}]}, + "empty_title": {"type": "title", "title": []}, + "rich": {"type": "rich_text", "rich_text": [{"plain_text": "RichText"}]}, + "empty_rich": {"type": "rich_text", "rich_text": []}, + "select_prop": {"type": "select", "select": {"name": "Selected"}}, + "empty_select": {"type": "select", "select": None}, + "status_prop": {"type": "status", "status": {"name": "Open"}}, + "empty_status": {"type": "status", "status": None}, + "number_prop": {"type": "number", "number": 10}, + "dict_prop": {"type": "date", "date": {"start": "2024-01-01", "end": None}}, + }, + "url": "https://notion.so/page-1", + } + ], + "has_more": True, + "next_cursor": "cursor-2", + } + second_page = {"results": [], "has_more": False, "next_cursor": None} + + mock_post = mocker.patch("httpx.post", side_effect=[_mock_response(first_page), _mock_response(second_page)]) + + docs = extractor._get_notion_database_data("db-1", query_dict={"filter": {"x": 1}}) + + assert len(docs) == 1 + content = docs[0].page_content + assert "tags:['A', 'B']" in content + assert "title_prop:Title" in content + assert "rich:RichText" in content + assert "number_prop:10" in content + assert "dict_prop:start:2024-01-01" in content + assert "Row Page URL:https://notion.so/page-1" in content + assert mock_post.call_count == 2 + + def test_get_notion_database_data_handles_missing_results_and_empty_content(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.post", return_value=_mock_response({"results": None})) + assert extractor._get_notion_database_data("db-1") == [] + + def test_get_notion_database_data_requires_access_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + 
notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + extractor._notion_access_token = None + + with pytest.raises(AssertionError, match="Notion access token is required"): + extractor._get_notion_database_data("db-1") -def test_notion_database(mocker: MockerFixture): - page_title_list = ["page1", "page2", "page3"] - mocked_notion_database = { - "object": "list", - "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None, - } - mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database)) - database_docs = extractor._load_data_as_documents(database_id, "database") - assert len(database_docs) == 1 - content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == "\n".join([f"Page:{i}" for i in page_title_list]) +class TestNotionBlocks: + def test_get_notion_block_data_success_with_table_headings_children_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + first_response = { + "results": [ + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + { + "type": "heading_1", + "id": "h1", + "has_children": False, + "heading_1": {"rich_text": [{"text": {"content": "Heading"}}]}, + }, + { + "type": "paragraph", + "id": "p1", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Paragraph"}}]}, + }, + { + "type": "child_page", + "id": "cp1", + "has_children": True, + "child_page": {"rich_text": []}, + }, + ], + "next_cursor": "cursor-2", + } + second_response = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "SubHeading"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(first_response), 
_mock_response(second_response)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE") + mocker.patch.object(extractor, "_read_block", return_value="CHILD") + + lines = extractor._get_notion_block_data("page-1") + + assert lines[0] == "TABLE\n\n" + assert "# Heading" in lines[1] + assert "Paragraph\nCHILD\n\n" in lines[2] + assert "## SubHeading" in lines[-1] + + def test_get_notion_block_data_handles_http_error_and_invalid_payload(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.request", side_effect=httpx.HTTPError("network")) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"bad": "payload"}, status_code=200)) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"results": []}, status_code=500, text="boom")) + with pytest.raises(ValueError, match="Error fetching Notion block data: boom"): + extractor._get_notion_block_data("page-1") + + def test_read_block_supports_heading_table_and_recursion(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + root_payload = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "Root"}}]}, + }, + { + "type": "paragraph", + "id": "child-block", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Parent"}}]}, + }, + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + ], + 
"next_cursor": None, + } + child_payload = { + "results": [ + { + "type": "paragraph", + "id": "leaf", + "has_children": False, + "paragraph": {"rich_text": [{"text": {"content": "Child"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(root_payload), _mock_response(child_payload)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE-MD") + + content = extractor._read_block("root") + + assert "## Root" in content + assert "Parent" in content + assert "Child" in content + assert "TABLE-MD" in content + + def test_read_block_breaks_on_missing_results(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + mocker.patch("httpx.request", return_value=_mock_response({"results": None, "next_cursor": None})) + + assert extractor._read_block("root") == "" + + def test_read_table_rows_formats_markdown_with_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + page_one = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [{"text": {"content": "H2"}}], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R1C1"}}], + [{"text": {"content": "R1C2"}}], + ] + } + }, + ], + "next_cursor": "next", + } + page_two = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R2C1"}}], + [{"text": {"content": "R2C2"}}], + ] + } + }, + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(page_one), _mock_response(page_two)]) + + markdown = extractor._read_table_rows("tbl-1") + + assert "| H1 | H2 
|" in markdown + assert "| R1C1 | R1C2 |" in markdown + assert "| H1 | |" in markdown + assert "| R2C1 | R2C2 |" in markdown + + +class TestNotionMetadataAndCredentialMethods: + def test_update_last_edited_time_no_document_model(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + assert extractor.update_last_edited_time(None) is None + + def test_update_last_edited_time_updates_document_and_commits(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + class FakeDocumentModel: + data_source_info = "data_source_info" + + update_calls = [] + + class FakeQuery: + def filter_by(self, **kwargs): + return self + + def update(self, payload): + update_calls.append(payload) + + class FakeSession: + committed = False + + def query(self, model): + assert model is FakeDocumentModel + return FakeQuery() + + def commit(self): + self.committed = True + + fake_db = SimpleNamespace(session=FakeSession()) + monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel) + monkeypatch.setattr(notion_extractor, "db", fake_db) + monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z") + + doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"}) + extractor.update_last_edited_time(doc_model) + + assert update_calls + assert fake_db.session.committed is True + + def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture): + extractor_page = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="page-id", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + request_mock = mocker.patch( + "httpx.request", 
return_value=_mock_response({"last_edited_time": "2025-05-01T00:00:00.000Z"}) + ) + + assert extractor_page.get_notion_last_edited_time() == "2025-05-01T00:00:00.000Z" + assert "pages/page-id" in request_mock.call_args[0][1] + + extractor_db = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="db-id", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + request_mock = mocker.patch( + "httpx.request", return_value=_mock_response({"last_edited_time": "2025-06-01T00:00:00.000Z"}) + ) + + assert extractor_db.get_notion_last_edited_time() == "2025-06-01T00:00:00.000Z" + assert "databases/db-id" in request_mock.call_args[0][1] + + def test_get_notion_last_edited_time_requires_access_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + extractor._notion_access_token = None + + with pytest.raises(AssertionError, match="Notion access token is required"): + extractor.get_notion_last_edited_time() + + def test_get_access_token_success_and_errors(self, monkeypatch): + with pytest.raises(Exception, match="No credential id found"): + notion_extractor.NotionExtractor._get_access_token("tenant", None) + + class FakeProviderServiceMissing: + def get_datasource_credentials(self, **kwargs): + return None + + monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceMissing) + with pytest.raises(Exception, match="No notion credential found"): + notion_extractor.NotionExtractor._get_access_token("tenant", "cred") + + class FakeProviderServiceFound: + def get_datasource_credentials(self, **kwargs): + return {"integration_secret": "token-from-credential"} + + monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceFound) + + assert notion_extractor.NotionExtractor._get_access_token("tenant", "cred") == 
"token-from-credential" diff --git a/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py new file mode 100644 index 0000000000..fb3c6e52c6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.text_extractor as text_module +from core.rag.extractor.text_extractor import TextExtractor + + +class TestTextExtractor: + def test_extract_success(self, tmp_path): + file_path = tmp_path / "data.txt" + file_path.write_text("hello world", encoding="utf-8") + + extractor = TextExtractor(str(file_path)) + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "hello world" + assert docs[0].metadata == {"source": str(file_path)} + + def test_extract_autodetect_success_after_decode_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=True) + + calls = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + if encoding == "bad": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + return "decoded text" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + text_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")], + ) + + docs = extractor.extract() + + assert docs[0].page_content == "decoded text" + assert calls == [None, "bad", "utf-8"] + + def test_extract_autodetect_all_fail_raises_runtime_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=True) + + def always_decode_error(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + + monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True) + 
monkeypatch.setattr(text_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="latin-1")]) + + with pytest.raises(RuntimeError, match="all detected encodings failed"): + extractor.extract() + + def test_extract_decode_error_without_autodetect_raises_runtime_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=False) + + def always_decode_error(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + + monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True) + + with pytest.raises(RuntimeError, match="specified encoding failed"): + extractor.extract() + + def test_extract_wraps_non_decode_exceptions(self, monkeypatch): + extractor = TextExtractor("dummy.txt") + + def raise_other(self, encoding=None): + raise OSError("io error") + + monkeypatch.setattr(Path, "read_text", raise_other, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy.txt"): + extractor.extract() diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 0792ada194..64eb89590a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -3,9 +3,12 @@ import io import os import tempfile +from collections import UserDict from pathlib import Path from types import SimpleNamespace +from unittest.mock import MagicMock +import pytest from docx import Document from docx.oxml import OxmlElement from docx.oxml.ns import qn @@ -136,7 +139,7 @@ def test_extract_images_from_docx(monkeypatch): monkeypatch.setattr(we, "UploadFile", FakeUploadFile) # Patch external image fetcher - def fake_get(url: str): + def fake_get(url: str, **kwargs): assert url == "https://example.com/image.png" return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes) @@ -203,10 +206,8 @@ def 
test_extract_images_from_docx_uses_internal_files_url(): finally: # Restore original values - if original_files_url is not None: - dify_config.FILES_URL = original_files_url - if original_internal_files_url is not None: - dify_config.INTERNAL_FILES_URL = original_internal_files_url + dify_config.FILES_URL = original_files_url + dify_config.INTERNAL_FILES_URL = original_internal_files_url def test_extract_hyperlinks(monkeypatch): @@ -314,3 +315,405 @@ def test_extract_legacy_hyperlinks(monkeypatch): finally: if os.path.exists(tmp_path): os.remove(tmp_path) + + +def test_init_rejects_invalid_url_status(monkeypatch): + class FakeResponse: + status_code = 404 + content = b"" + closed = False + + def close(self): + self.closed = True + + fake_response = FakeResponse() + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: fake_response)) + + with pytest.raises(ValueError, match="returned status code 404"): + WordExtractor("https://example.com/missing.docx", "tenant", "user") + + assert fake_response.closed is True + + +def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): + target_file = tmp_path / "expanded.docx" + target_file.write_bytes(b"docx") + + monkeypatch.setattr(we.os.path, "expanduser", lambda p: str(target_file)) + monkeypatch.setattr( + we.os.path, + "isfile", + lambda p: p == str(target_file), + ) + + extractor = WordExtractor("~/expanded.docx", "tenant", "user") + assert extractor.file_path == str(target_file) + + monkeypatch.setattr(we.os.path, "isfile", lambda p: False) + with pytest.raises(ValueError, match="is not a valid file or url"): + WordExtractor("not-a-file", "tenant", "user") + + +def test_del_closes_temp_file(): + extractor = object.__new__(WordExtractor) + extractor.temp_file = MagicMock() + + WordExtractor.__del__(extractor) + + extractor.temp_file.close.assert_called_once() + + +def test_extract_images_handles_invalid_external_cases(monkeypatch): + class FakeTargetRef: + def 
__contains__(self, item): + return item == "image" + + def split(self, sep): + return [None] + + rel_invalid_url = SimpleNamespace(is_external=True, target_ref="image-no-url") + rel_request_error = SimpleNamespace(is_external=True, target_ref="https://example.com/image-error") + rel_unknown_mime = SimpleNamespace(is_external=True, target_ref="https://example.com/image-unknown") + rel_internal_none_ext = SimpleNamespace(is_external=False, target_ref=FakeTargetRef(), target_part=object()) + + doc = SimpleNamespace( + part=SimpleNamespace( + rels={ + "r1": rel_invalid_url, + "r2": rel_request_error, + "r3": rel_unknown_mime, + "r4": rel_internal_none_ext, + } + ) + ) + + def fake_get(url, **kwargs): + if "image-error" in url: + raise RuntimeError("network") + return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x") + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock())) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None)) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "tenant" + extractor.user_id = "user" + + result = extractor._extract_images_from_docx(doc) + + assert result == {} + db_stub.session.commit.assert_called_once() + + +def test_table_to_markdown_and_parse_helpers(monkeypatch): + extractor = object.__new__(WordExtractor) + + table = SimpleNamespace( + rows=[ + SimpleNamespace(cells=[1, 2]), + SimpleNamespace(cells=[3, 4]), + ] + ) + parse_row_mock = MagicMock(side_effect=[["H1", "H2"], ["A", "B"]]) + monkeypatch.setattr(extractor, "_parse_row", parse_row_mock) + + markdown = extractor._table_to_markdown(table, {}) + assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |" + + class FakeBlip: + def __init__(self, 
image_id): + self.image_id = image_id + + def get(self, key): + return self.image_id + + class FakeRunChild: + def __init__(self, blips, text=""): + self._blips = blips + self.text = text + self.tag = qn("w:r") + + def xpath(self, pattern): + if pattern == ".//a:blip": + return self._blips + return [] + + class FakeRun: + def __init__(self, element, paragraph): + # Mirror the subset used by _parse_cell_paragraph + self.element = element + self.text = getattr(element, "text", "") + + # Patch we.Run so our lightweight child objects work with the extractor + monkeypatch.setattr(we, "Run", FakeRun) + + image_part = object() + paragraph = SimpleNamespace( + _element=[ + FakeRunChild([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")], text=""), + FakeRunChild([], text="plain"), + ], + part=SimpleNamespace( + rels={ + "ext": SimpleNamespace(is_external=True), + "int": SimpleNamespace(is_external=False, target_part=image_part), + } + ), + ) + + image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"} + assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain" + + cell = SimpleNamespace(paragraphs=[paragraph, paragraph]) + assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain" + + +def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch): + extractor = object.__new__(WordExtractor) + + ext_image_id = "ext-image" + int_embed_id = "int-embed" + shape_ext_id = "shape-ext" + shape_int_id = "shape-int" + + internal_part = object() + shape_internal_part = object() + + class Rels(UserDict): + def get(self, key, default=None): + if key == "link-bad": + raise RuntimeError("cannot resolve relation") + return super().get(key, default) + + rels = Rels( + { + ext_image_id: SimpleNamespace(is_external=True, target_ref="https://img/ext.png"), + int_embed_id: SimpleNamespace(is_external=False, target_part=internal_part), + shape_ext_id: SimpleNamespace(is_external=True, target_ref="https://img/shape.png"), + 
shape_int_id: SimpleNamespace(is_external=False, target_part=shape_internal_part), + "link-ok": SimpleNamespace(is_external=True, target_ref="https://example.com"), + } + ) + + image_map = { + ext_image_id: "[EXT]", + internal_part: "[INT]", + shape_ext_id: "[SHAPE_EXT]", + shape_internal_part: "[SHAPE_INT]", + } + + class FakeBlip: + def __init__(self, embed_id): + self.embed_id = embed_id + + def get(self, key): + return self.embed_id + + class FakeDrawing: + def __init__(self, embed_ids): + self.embed_ids = embed_ids + + def findall(self, pattern): + return [FakeBlip(embed_id) for embed_id in self.embed_ids] + + class FakeNode: + def __init__(self, text=None, attrs=None): + self.text = text + self._attrs = attrs or {} + + def get(self, key): + return self._attrs.get(key) + + class FakeShape: + def __init__(self, bin_id=None, img_id=None): + self.bin_id = bin_id + self.img_id = img_id + + def find(self, pattern): + if "binData" in pattern and self.bin_id: + return FakeNode( + text="shape", + attrs={"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id": self.bin_id}, + ) + if "imagedata" in pattern and self.img_id: + return FakeNode(attrs={"id": self.img_id}) + return None + + class FakeChild: + def __init__( + self, + tag, + text="", + fld_chars=None, + instr_texts=None, + drawings=None, + shapes=None, + attrs=None, + hyperlink_runs=None, + ): + self.tag = tag + self.text = text + self._fld_chars = fld_chars or [] + self._instr_texts = instr_texts or [] + self._drawings = drawings or [] + self._shapes = shapes or [] + self._attrs = attrs or {} + self._hyperlink_runs = hyperlink_runs or [] + + def findall(self, pattern): + if pattern == qn("w:fldChar"): + return self._fld_chars + if pattern == qn("w:instrText"): + return self._instr_texts + if pattern == qn("w:r"): + return self._hyperlink_runs + if pattern.endswith("}drawing"): + return self._drawings + if pattern.endswith("}pict"): + return self._shapes + return [] + + def get(self, key): + 
return self._attrs.get(key) + + class FakeRun: + def __init__(self, element, paragraph): + self.element = element + self.text = getattr(element, "text", "") + + paragraph_main = SimpleNamespace( + _element=[ + FakeChild( + qn("w:r"), + text="run-text", + drawings=[FakeDrawing([ext_image_id, int_embed_id])], + shapes=[FakeShape(bin_id=shape_ext_id, img_id=shape_int_id)], + ), + FakeChild( + qn("w:r"), + text="", + drawings=[], + shapes=[FakeShape(bin_id=shape_ext_id)], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-ok"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="LinkText")], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-bad"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="BrokenLink")], + ), + ] + ) + paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")]) + + fake_doc = SimpleNamespace( + part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}), + paragraphs=[paragraph_main, paragraph_empty], + tables=[SimpleNamespace(rows=[])], + element=SimpleNamespace( + body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")] + ), + ) + + monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc) + monkeypatch.setattr(we, "Run", FakeRun) + monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map) + monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN") + logger_exception = MagicMock() + monkeypatch.setattr(we.logger, "exception", logger_exception) + + content = extractor.parse_docx("dummy.docx") + + assert "[EXT]" in content + assert "[INT]" in content + assert "[SHAPE_EXT]" in content + assert "[LinkText](https://example.com)" in content + assert "BrokenLink" in content + assert "TABLE-MARKDOWN" in content + logger_exception.assert_called_once() + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_http(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + 
p = cell.paragraphs[0] + + # Build modern hyperlink inside table cell + r_id = "rIdHttp1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + # Relationship for external http link + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[Dify](https://dify.ai)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + r_id = "rIdMail1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "john@test.com" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "mailto:john@test.com", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[john@test.com](mailto:john@test.com)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git 
a/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py new file mode 100644 index 0000000000..26ce333e11 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py @@ -0,0 +1,300 @@ +"""Unit tests for unstructured extractors and their local/API partitioning paths.""" + +import base64 +import sys +import types +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.unstructured.unstructured_epub_extractor as epub_module +from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor +from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor +from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor +from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor +from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor +from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor +from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor + + +def _register_module(monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object) -> types.ModuleType: + module = types.ModuleType(name) + for k, v in attrs.items(): + setattr(module, k, v) + monkeypatch.setitem(sys.modules, name, module) + return module + + +def _register_unstructured_packages(monkeypatch: pytest.MonkeyPatch) -> None: + _register_module(monkeypatch, "unstructured", __path__=[]) + _register_module(monkeypatch, "unstructured.partition", __path__=[]) + _register_module(monkeypatch, "unstructured.chunking", __path__=[]) + _register_module(monkeypatch, 
"unstructured.file_utils", __path__=[]) + + +def _install_chunk_by_title(monkeypatch: pytest.MonkeyPatch, chunks: list[SimpleNamespace]) -> None: + _register_unstructured_packages(monkeypatch) + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + return chunks + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + +class TestUnstructuredMarkdownMsgXml: + def test_markdown_extractor_without_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" chunk-1 "), SimpleNamespace(text=" chunk-2 ")]) + _register_module( + monkeypatch, "unstructured.partition.md", partition_md=lambda filename: [SimpleNamespace(text="x")] + ) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md").extract() + + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_markdown_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" via-api ")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="ignored")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "via-api" + assert calls == {"filename": "/tmp/file.md", "api_url": "https://u", "api_key": "k"} + + def test_msg_extractor_local(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + _register_module( + monkeypatch, "unstructured.partition.msg", partition_msg=lambda filename: [SimpleNamespace(text="x")] + ) + + assert UnstructuredMsgExtractor("/tmp/file.msg").extract()[0].page_content == "msg-doc" + + def test_msg_extractor_with_api(self, monkeypatch): + 
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredMsgExtractor("/tmp/file.msg", api_url="https://u", api_key="k").extract()[0].page_content + == "msg-doc" + ) + assert calls["filename"] == "/tmp/file.msg" + + def test_xml_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="xml-doc")]) + + xml_calls = {} + + def partition_xml(filename, xml_keep_tags): + xml_calls.update({"filename": filename, "xml_keep_tags": xml_keep_tags}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.xml", partition_xml=partition_xml) + + assert UnstructuredXmlExtractor("/tmp/file.xml").extract()[0].page_content == "xml-doc" + assert xml_calls == {"filename": "/tmp/file.xml", "xml_keep_tags": True} + + api_calls = {} + + def partition_via_api(filename, api_url, api_key): + api_calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredXmlExtractor("/tmp/file.xml", api_url="https://u", api_key="k").extract()[0].page_content + == "xml-doc" + ) + assert api_calls["filename"] == "/tmp/file.xml" + + +class TestUnstructuredEmailAndEpub: + def test_email_extractor_local_decodes_html_and_suppresses_decode_errors(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + captured = {} + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + captured["elements"] = list(elements) + return 
[SimpleNamespace(text=" chunked-email ")] + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + html = "

Hello Email

" + encoded_html = base64.b64encode(html.encode("utf-8")).decode("utf-8") + bad_base64 = "not-base64" + + elements = [SimpleNamespace(text=encoded_html), SimpleNamespace(text=bad_base64)] + _register_module(monkeypatch, "unstructured.partition.email", partition_email=lambda filename: elements) + + docs = UnstructuredEmailExtractor("/tmp/file.eml").extract() + + assert docs[0].page_content == "chunked-email" + chunk_elements = captured["elements"] + assert "Hello Email" in chunk_elements[0].text + assert chunk_elements[1].text == bad_base64 + + def test_email_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="api-email")]) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="abc")], + ) + + docs = UnstructuredEmailExtractor("/tmp/file.eml", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "api-email" + + def test_epub_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="epub-doc")]) + + calls = {"download": 0, "partition": 0} + + def fake_download_pandoc(): + calls["download"] += 1 + + def partition_epub(filename, xml_keep_tags): + calls["partition"] += 1 + assert xml_keep_tags is True + return [SimpleNamespace(text="x")] + + monkeypatch.setattr(epub_module.pypandoc, "download_pandoc", fake_download_pandoc) + _register_module(monkeypatch, "unstructured.partition.epub", partition_epub=partition_epub) + + docs = UnstructuredEpubExtractor("/tmp/file.epub").extract() + + assert docs[0].page_content == "epub-doc" + assert calls == {"download": 1, "partition": 1} + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="x")], + ) + + docs = UnstructuredEpubExtractor("/tmp/file.epub", api_url="https://u", api_key="k").extract() + assert docs[0].page_content == 
"epub-doc" + + +class TestUnstructuredPPTAndPPTX: + def test_ppt_extractor_requires_api_url(self): + with pytest.raises(NotImplementedError, match="Unstructured API Url is not configured"): + UnstructuredPPTExtractor("/tmp/file.ppt").extract() + + def test_ppt_extractor_groups_text_by_page(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="A", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="B", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="skip", metadata=SimpleNamespace(page_number=None)), + SimpleNamespace(text="C", metadata=SimpleNamespace(page_number=2)), + ], + ) + + docs = UnstructuredPPTExtractor("/tmp/file.ppt", api_url="https://u", api_key="k").extract() + + assert [doc.page_content for doc in docs] == ["A\nB", "C"] + + def test_pptx_extractor_local_and_api(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.pptx", + partition_pptx=lambda filename: [ + SimpleNamespace(text="P1", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="P2", metadata=SimpleNamespace(page_number=2)), + SimpleNamespace(text="Skip", metadata=SimpleNamespace(page_number=None)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx").extract() + assert [doc.page_content for doc in docs] == ["P1", "P2"] + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="X", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="Y", metadata=SimpleNamespace(page_number=1)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx", api_url="https://u", api_key="k").extract() + assert [doc.page_content for doc in docs] == ["X\nY"] + + +class TestUnstructuredWord: + def _install_doc_modules(self, 
monkeypatch, version: str, filetype_value): + _register_unstructured_packages(monkeypatch) + + class FileType: + DOC = "doc" + + _register_module(monkeypatch, "unstructured.__version__", __version__=version) + _register_module( + monkeypatch, + "unstructured.file_utils.filetype", + FileType=FileType, + detect_filetype=lambda filename: filetype_value, + ) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="api-doc")], + ) + _register_module( + monkeypatch, + "unstructured.partition.docx", + partition_docx=lambda filename: [SimpleNamespace(text="docx-doc")], + ) + _register_module( + monkeypatch, + "unstructured.chunking.title", + chunk_by_title=lambda elements, max_characters, combine_text_under_n_chars: [ + SimpleNamespace(text="chunk-1"), + SimpleNamespace(text="chunk-2"), + ], + ) + + def test_word_extractor_rejects_doc_on_old_unstructured_version(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="doc") + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + + def test_word_extractor_doc_and_docx_paths(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.11", filetype_value="doc") + + docs = UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + self._install_doc_modules(monkeypatch, version="0.5.0", filetype_value="not-doc") + docs = UnstructuredWordExtractor("/tmp/file.docx", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_word_extractor_magic_import_error_fallback_to_extension(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="not-used") + monkeypatch.setitem(sys.modules, "magic", None) + + with 
pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() diff --git a/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py new file mode 100644 index 0000000000..d758be218a --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py @@ -0,0 +1,434 @@ +"""Unit tests for WaterCrawl client, provider, and extractor behavior.""" + +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import core.rag.extractor.watercrawl.client as client_module +from core.rag.extractor.watercrawl.client import BaseAPIClient, WaterCrawlAPIClient +from core.rag.extractor.watercrawl.exceptions import ( + WaterCrawlAuthenticationError, + WaterCrawlBadRequestError, + WaterCrawlPermissionError, +) +from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor +from core.rag.extractor.watercrawl.provider import WaterCrawlProvider + + +def _response( + status_code: int, + json_data: dict[str, Any] | None = None, + content_type: str = "application/json", + content: bytes = b"", + text: str = "", +) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.headers = {"Content-Type": content_type} + response.content = content + response.text = text + response.json.return_value = json_data if json_data is not None else {} + response.raise_for_status.return_value = None + response.close.return_value = None + return response + + +class TestWaterCrawlExceptions: + def test_bad_request_error_properties_and_string(self): + response = _response(400, {"message": "bad request", "errors": {"url": ["invalid"]}}) + + err = WaterCrawlBadRequestError(response) + parsed_errors = json.loads(err.flat_errors) + + assert err.status_code == 400 + assert err.message == "bad request" + assert "url" in parsed_errors + assert 
any("invalid" in str(item) for item in parsed_errors["url"]) + assert "WaterCrawlBadRequestError" in str(err) + + def test_permission_and_authentication_error_strings(self): + response = _response(403, {"message": "quota exceeded", "errors": {}}) + + permission = WaterCrawlPermissionError(response) + authentication = WaterCrawlAuthenticationError(response) + + assert "exceeding your WaterCrawl API limits" in str(permission) + assert "API key is invalid or expired" in str(authentication) + + +class TestBaseAPIClient: + def test_init_session_builds_expected_headers(self, monkeypatch): + captured = {} + + def fake_client(**kwargs): + captured.update(kwargs) + return "session" + + monkeypatch.setattr(client_module.httpx, "Client", fake_client) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client.session == "session" + assert captured["headers"]["X-API-Key"] == "k" + assert captured["headers"]["User-Agent"] == "WaterCrawl-Plugin" + + def test_request_stream_and_non_stream_paths(self, monkeypatch): + class FakeSession: + def __init__(self): + self.request_calls = [] + self.build_calls = [] + self.send_calls = [] + + def request(self, method, url, params=None, json=None, **kwargs): + self.request_calls.append((method, url, params, json, kwargs)) + return "non-stream-response" + + def build_request(self, method, url, params=None, json=None): + req = (method, url, params, json) + self.build_calls.append(req) + return req + + def send(self, request, stream=False, **kwargs): + self.send_calls.append((request, stream, kwargs)) + return "stream-response" + + fake_session = FakeSession() + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: fake_session) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client._request("GET", "/v1/items", query_params={"a": 1}) == "non-stream-response" + assert fake_session.request_calls[0][1] == "https://watercrawl.dev/v1/items" + + assert 
client._request("GET", "/v1/items", stream=True) == "stream-response" + assert fake_session.build_calls + assert fake_session.send_calls[0][1] is True + + def test_http_method_helpers_delegate_to_request(self, monkeypatch): + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: MagicMock()) + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + calls = [] + + def fake_request(method, endpoint, query_params=None, data=None, **kwargs): + calls.append((method, endpoint, query_params, data)) + return "ok" + + monkeypatch.setattr(client, "_request", fake_request) + + assert client._get("/a") == "ok" + assert client._post("/b", data={"x": 1}) == "ok" + assert client._put("/c", data={"x": 2}) == "ok" + assert client._delete("/d") == "ok" + assert client._patch("/e", data={"x": 3}) == "ok" + assert [c[0] for c in calls] == ["GET", "POST", "PUT", "DELETE", "PATCH"] + + +class TestWaterCrawlAPIClient: + def test_process_eventstream_and_download(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = MagicMock() + response.iter_lines.return_value = [ + b"event: keep-alive", + b'data: {"type":"result","data":{"result":"http://x"}}', + b'data: {"type":"log","data":{"msg":"ok"}}', + ] + + monkeypatch.setattr(client, "download_result", lambda data: {"result": {"markdown": "body"}, "url": "u"}) + + events = list(client.process_eventstream(response, download=True)) + + assert events[0]["data"]["result"]["markdown"] == "body" + assert events[1]["type"] == "log" + response.close.assert_called_once() + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (401, WaterCrawlAuthenticationError), + (403, WaterCrawlPermissionError), + (422, WaterCrawlBadRequestError), + ], + ) + def test_process_response_error_statuses(self, status: int, expected_exception: type[Exception]): + client = WaterCrawlAPIClient(api_key="k") + + with pytest.raises(expected_exception): + client.process_response(_response(status, {"message": 
"bad", "errors": {"url": ["x"]}})) + + def test_process_response_204_returns_none(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(204, None)) is None + + def test_process_response_json_payloads(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(200, {"ok": True})) == {"ok": True} + assert client.process_response(_response(200, None)) == {} + + def test_process_response_octet_stream_returns_bytes(self): + client = WaterCrawlAPIClient(api_key="k") + assert ( + client.process_response(_response(200, content_type="application/octet-stream", content=b"bin")) == b"bin" + ) + + def test_process_response_event_stream_returns_generator(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + generator = (item for item in [{"type": "result", "data": {}}]) + monkeypatch.setattr(client, "process_eventstream", lambda response, download=False: generator) + assert client.process_response(_response(200, content_type="text/event-stream")) is generator + + def test_process_response_unknown_content_type_raises(self): + client = WaterCrawlAPIClient(api_key="k") + with pytest.raises(Exception, match="Unknown response type"): + client.process_response(_response(200, content_type="text/plain", text="x")) + + def test_process_response_uses_raise_for_status(self): + client = WaterCrawlAPIClient(api_key="k") + response = _response(500, {"message": "server"}) + response.raise_for_status.side_effect = RuntimeError("http error") + + with pytest.raises(RuntimeError, match="http error"): + client.process_response(response) + + def test_endpoint_wrappers(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda resp: "processed") + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "get-resp") + monkeypatch.setattr(client, "_post", lambda *args, **kwargs: "post-resp") + monkeypatch.setattr(client, "_delete", lambda *args, 
**kwargs: "delete-resp") + + assert client.get_crawl_requests_list() == "processed" + assert client.get_crawl_request("id") == "processed" + assert client.create_crawl_request(url="https://x") == "processed" + assert client.stop_crawl_request("id") == "processed" + assert client.download_crawl_request("id") == "processed" + assert client.get_crawl_request_results("id") == "processed" + + def test_monitor_crawl_request_generator_and_validation(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda _: (x for x in [{"type": "result", "data": 1}])) + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "stream-resp") + + events = list(client.monitor_crawl_request("job-1", prefetched=True)) + assert events == [{"type": "result", "data": 1}] + + monkeypatch.setattr(client, "process_response", lambda _: [{"type": "result"}]) + with pytest.raises(ValueError, match="Generator expected"): + list(client.monitor_crawl_request("job-1")) + + def test_scrape_url_sync_and_async(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + monkeypatch.setattr(client, "create_crawl_request", lambda **kwargs: {"uuid": "job-1"}) + + async_result = client.scrape_url("https://example.com", sync=False) + assert async_result == {"uuid": "job-1"} + + monkeypatch.setattr( + client, + "monitor_crawl_request", + lambda item_id, prefetched: iter( + [{"type": "log", "data": {}}, {"type": "result", "data": {"url": "https://example.com"}}] + ), + ) + sync_result = client.scrape_url("https://example.com", sync=True) + assert sync_result == {"url": "https://example.com"} + + def test_download_result_fetches_json_and_closes(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = _response(200, {"markdown": "body"}) + monkeypatch.setattr(client_module.httpx, "get", lambda *args, **kwargs: response) + + result = client.download_result({"result": "https://example.com/result.json"}) + + assert 
result["result"] == {"markdown": "body"} + response.close.assert_called_once() + + +class TestWaterCrawlProvider: + def test_crawl_url_builds_options_and_min_wait_time(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + captured_kwargs = {} + + def create_crawl_request_spy(**kwargs): + captured_kwargs.update(kwargs) + return {"uuid": "job-1"} + + monkeypatch.setattr(provider.client, "create_crawl_request", create_crawl_request_spy) + + result = provider.crawl_url( + "https://example.com", + { + "crawl_sub_pages": True, + "limit": 5, + "max_depth": 2, + "includes": "a,b", + "excludes": "x,y", + "exclude_tags": "nav,footer", + "include_tags": "main", + "wait_time": 100, + "only_main_content": False, + }, + ) + + assert result == {"status": "active", "job_id": "job-1"} + assert captured_kwargs["url"] == "https://example.com" + assert captured_kwargs["spider_options"] == { + "max_depth": 2, + "page_limit": 5, + "allowed_domains": [], + "exclude_paths": ["x", "y"], + "include_paths": ["a", "b"], + } + assert captured_kwargs["page_options"]["exclude_tags"] == ["nav", "footer"] + assert captured_kwargs["page_options"]["include_tags"] == ["main"] + assert captured_kwargs["page_options"]["only_main_content"] is False + assert captured_kwargs["page_options"]["wait_time"] == 1000 + + def test_get_crawl_status_active_and_completed(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "running", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 3}}, + "number_of_documents": 1, + "duration": "00:00:01.500000", + }, + ) + + active = provider.get_crawl_status("job-1") + assert active["status"] == "active" + assert active["data"] == [] + assert active["time_consuming"] == pytest.approx(1.5) + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "completed", + "uuid": job_id, + "options": 
{"spider_options": {"page_limit": 2}}, + "number_of_documents": 2, + "duration": "00:00:02.000000", + }, + ) + monkeypatch.setattr(provider, "_get_results", lambda crawl_request_id, query_params=None: iter([{"url": "u"}])) + + completed = provider.get_crawl_status("job-2") + assert completed["status"] == "completed" + assert completed["data"] == [{"url": "u"}] + + def test_get_crawl_url_data_and_scrape(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr(provider, "scrape_url", lambda url: {"source_url": url}) + assert provider.get_crawl_url_data("", "https://example.com") == {"source_url": "https://example.com"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([{"source_url": "u1"}])) + assert provider.get_crawl_url_data("job", "u1") == {"source_url": "u1"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([])) + assert provider.get_crawl_url_data("job", "u1") is None + + def test_structure_data_validation_and_get_results_pagination(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + with pytest.raises(ValueError, match="Invalid result object"): + provider._structure_data({"result": "not-a-dict"}) + + structured = provider._structure_data( + { + "url": "https://example.com", + "result": { + "metadata": {"title": "Title", "description": "Desc"}, + "markdown": "Body", + }, + } + ) + assert structured["title"] == "Title" + assert structured["markdown"] == "Body" + + responses = [ + { + "results": [ + { + "url": "https://a", + "result": {"metadata": {"title": "A", "description": "DA"}, "markdown": "MA"}, + } + ], + "next": "next-page", + }, + {"results": [], "next": None}, + ] + + monkeypatch.setattr( + provider.client, + "get_crawl_request_results", + lambda crawl_request_id, page, page_size, query_params: responses.pop(0), + ) + + results = list(provider._get_results("job-1")) + assert len(results) == 1 + assert results[0]["source_url"] 
== "https://a" + + def test_scrape_url_uses_client_and_structure(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + monkeypatch.setattr( + provider.client, "scrape_url", lambda **kwargs: {"result": {"metadata": {}, "markdown": "m"}, "url": "u"} + ) + + result = provider.scrape_url("u") + + assert result["source_url"] == "u" + + +class TestWaterCrawlWebExtractor: + def test_extract_crawl_and_scrape_modes(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: { + "markdown": "crawl", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_scrape_url_data", + lambda provider, url, tenant_id, only_main_content: { + "markdown": "scrape", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + + crawl_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + scrape_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert crawl_extractor.extract()[0].page_content == "crawl" + assert scrape_extractor.extract()[0].page_content == "scrape" + + def test_extract_crawl_returns_empty_when_service_returns_none(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: None, + ) + + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_unknown_mode_returns_empty(self): + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="other") + + assert extractor.extract() == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/conftest.py b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py new file mode 100644 index 
0000000000..2a3860e107 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py @@ -0,0 +1,33 @@ +from contextlib import AbstractContextManager, nullcontext +from typing import Any + +import pytest + + +class _FakeFlaskApp: + def app_context(self) -> AbstractContextManager[None]: + return nullcontext() + + +class _FakeExecutor: + def __init__(self, future: Any) -> None: + self._future = future + + def __enter__(self) -> "_FakeExecutor": + return self + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool: + return False + + def submit(self, func: object, preview: object) -> Any: + return self._future + + +@pytest.fixture +def fake_flask_app() -> _FakeFlaskApp: + return _FakeFlaskApp() + + +@pytest.fixture +def fake_executor_cls() -> type[_FakeExecutor]: + return _FakeExecutor diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py new file mode 100644 index 0000000000..2451db70b6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -0,0 +1,629 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from dify_graph.model_runtime.entities.model_entities import ModelFeature + + +class TestParagraphIndexProcessor: + @pytest.fixture + def processor(self) -> ParagraphIndexProcessor: + return ParagraphIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + 
dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def _llm_result(self, content: str = "summary") -> LLMResult: + return LLMResult( + model="llm-model", + message=AssistantPromptMessage(content=content), + usage=LLMUsage.empty_usage(), + ) + + def test_extract_forwards_automatic_flag(self, processor: ParagraphIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected_docs + docs = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParagraphIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + rules_without_segmentation = 
SimpleNamespace(segmentation=None) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=rules_without_segmentation, + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".first", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + return_value=".first", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + ): + documents = processor.transform([source_document], process_rule=process_rule) + + assert len(documents) == 1 + assert documents[0].page_content == "first" + assert documents[0].attachments is not None + assert documents[0].metadata["doc_hash"] == "hash" + + def test_transform_automatic_mode_uses_default_rules(self, processor: ParagraphIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content="text", metadata={})] 
+ + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ) as mock_validate, + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_get_content_files", return_value=[]), + ): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "automatic"}) + + assert mock_validate.call_count == 1 + + def test_load_creates_vector_and_multimodal_when_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + docs = [Document(page_content="chunk", metadata={})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + vector = mock_vector_cls.return_value + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + mock_keyword_cls.assert_not_called() + + def test_load_uses_keyword_add_texts_with_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs, keywords_list=["k1", "k2"]) + + 
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs, keywords_list=["k1", "k2"]) + + def test_load_uses_keyword_add_texts_without_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) + + def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_economy_deletes_summaries_and_keywords( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + 
mock_keyword_cls.return_value.delete.assert_called_once() + + def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.clean(dataset, ["node-2"], with_keywords=True) + + mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"]) + + def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9) + rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [accepted, rejected] + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + def test_index_list_chunks_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, ["chunk-1", "chunk-2"]) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_list_chunks_economy( + 
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.index(dataset, dataset_document, ["chunk-3"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once() + + def test_index_multimodal_structure_handles_files_and_account_lookup( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + chunk_with_files = SimpleNamespace( + content="content-1", + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + chunk_without_files = SimpleNamespace(content="content-2", files=None) + structure = SimpleNamespace(general_chunks=[chunk_with_files, chunk_without_files]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector"), + ): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + assert mock_files.call_count == 1 + + def 
test_index_multimodal_structure_requires_valid_account( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + structure = SimpleNamespace(general_chunks=[SimpleNamespace(content="content", files=None)]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + def test_format_preview_validates_chunk_shape(self, processor: ParagraphIndexProcessor) -> None: + preview = processor.format_preview(["chunk-1", "chunk-2"]) + assert preview["chunk_structure"] == "text_model" + assert preview["total_segments"] == 2 + + with pytest.raises(ValueError, match="Chunks is not a list"): + processor.format_preview({"not": "a-list"}) + + def test_generate_summary_preview_success_and_failure(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())): + result = processor.generate_summary_preview( + "tenant-1", preview_items, {"enable": True}, doc_language="English" + ) + assert all(item.summary == "summary" for item in result) + + with patch.object(processor, "generate_summary", side_effect=RuntimeError("summary failed")): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", [PreviewDetail(content="chunk-1")], {"enable": True}) + + def test_generate_summary_preview_fallback_without_flask_context(self, 
processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())), + ): + result = processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_timeout( + self, processor: ParagraphIndexProcessor, fake_executor_cls: type + ) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + future.cancel.assert_called_once() + + def test_generate_summary_validates_input(self) -> None: + with pytest.raises(ValueError, match="must be enabled"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": False}) + + with pytest.raises(ValueError, match="model_name and model_provider_name"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True}) + + def test_generate_summary_text_only_flow(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + model_instance.invoke_llm.return_value = self._llm_result("text summary") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + 
"core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota", + side_effect=RuntimeError("quota"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, usage = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + document_language="English", + ) + + assert summary == "text summary" + assert isinstance(usage, LLMUsage) + mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota") + + def test_generate_summary_handles_vision_and_image_conversion(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = self._llm_result("vision summary") + image_file = SimpleNamespace() + image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch.object( + ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[image_file] + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[]) as mock_extract_text, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + return_value=image_content, + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, _ = 
ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + segment_id="seg-1", + ) + + assert summary == "vision summary" + mock_extract_text.assert_not_called() + + def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = object() + image_file = SimpleNamespace() + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DEFAULT_GENERATOR_SUMMARY_PROMPT", + "Prompt {missing}", + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[]), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[image_file]), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + side_effect=RuntimeError("bad image"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + with pytest.raises(ValueError, match="Expected LLMResult"): + ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + ) + + mock_logger.warning.assert_called_with( + "Failed to convert image file to prompt message content: %s", "bad image" + ) + + def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None: + text = ( + 
"![img](/files/11111111-1111-1111-1111-111111111111/image-preview) " + "![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) " + "![tool](/files/tools/33333333-3333-3333-3333-333333333333.png)" + ) + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + non_image_upload = SimpleNamespace( + id="22222222-2222-2222-2222-222222222222", + tenant_id="tenant-1", + name="file.txt", + mime_type="text/plain", + extension="txt", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload, non_image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + return_value=SimpleNamespace(id="file-1"), + ) as mock_builder, + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert len(files) == 1 + assert mock_builder.call_count == 1 + mock_logger.warning.assert_not_called() + + def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None: + assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == [] + + def test_extract_images_from_text_logs_when_build_fails(self) -> None: + text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)" + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload] + session = Mock() + 
session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + side_effect=RuntimeError("build failed"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert files == [] + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments(self) -> None: + image_upload = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="", + size=1, + key="k1", + ) + bad_upload = SimpleNamespace( + id="file-2", + name="broken", + extension=None, + mime_type="image/png", + source_url="", + size=1, + key="k2", + ) + non_image_upload = SimpleNamespace( + id="file-3", + name="text", + extension="txt", + mime_type="text/plain", + source_url="", + size=1, + key="k3", + ) + execute_result = Mock() + execute_result.all.return_value = [(None, image_upload), (None, bad_upload), (None, non_image_upload)] + session = Mock() + session.execute.return_value = execute_result + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert len(files) == 1 + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments_empty(self) -> None: + execute_result = Mock() + execute_result.all.return_value = [] + session = Mock() + session.execute.return_value = execute_result + + with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session): + empty_files = 
ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert empty_files == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py new file mode 100644 index 0000000000..abe40f05d1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -0,0 +1,523 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.models.document import AttachmentDocument, ChildDocument, Document +from services.entities.knowledge_entities.knowledge_entities import ParentMode + + +class TestParentChildIndexProcessor: + @pytest.fixture + def processor(self) -> ParentChildIndexProcessor: + return ParentChildIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + document.dataset_process_rule_id = None + return document + + def _segmentation(self) -> SimpleNamespace: + return SimpleNamespace(max_tokens=200, chunk_overlap=10, separator="\n") + + def _paragraph_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + segmentation=self._segmentation(), + subchunk_segmentation=self._segmentation(), + ) + + def _full_doc_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.FULL_DOC, segmentation=None, subchunk_segmentation=self._segmentation() + ) + + def 
test_extract_forwards_automatic_flag(self, processor: ParentChildIndexProcessor) -> None: + extract_setting = Mock() + expected = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected + documents = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert documents == expected + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParentChildIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_paragraph_requires_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(parent_mode=ParentMode.PARAGRAPH, segmentation=None) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", return_value=rules + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_paragraph_builds_parent_and_child_docs(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".parent", metadata={}), + Document(page_content=" ", metadata={}), + ] + parent_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + child_docs = [ChildDocument(page_content="child-1", metadata={"dataset_id": "dataset-1"})] + + with 
( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + return_value=".parent", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + ): + result = processor.transform( + [parent_document], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=False, + ) + + assert len(result) == 1 + assert result[0].page_content == "parent" + assert result[0].children == child_docs + assert result[0].attachments is not None + + def test_transform_preview_returns_after_ten_parent_chunks(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content=f"chunk-{i}", metadata={}) for i in range(10)] + documents = [ + Document(page_content="doc-1", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"dataset_id": "dataset-1", "document_id": "doc-2"}), + ] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + 
patch.object(processor, "_split_child_nodes", return_value=[]), + ): + result = processor.transform( + documents, + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 10 + + def test_transform_full_doc_mode_trims_children_for_preview(self, processor: ParentChildIndexProcessor) -> None: + docs = [ + Document(page_content="first", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="second", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + ] + child_docs = [ChildDocument(page_content=f"child-{i}", metadata={}) for i in range(5)] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._full_doc_rules(), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.dify_config.CHILD_CHUNKS_PREVIEW_NUMBER", + 2, + ), + ): + result = processor.transform( + docs, + process_rule={"mode": "hierarchical", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 1 + assert len(result[0].children or []) == 2 + assert result[0].attachments is not None + + def test_load_creates_vectors_for_child_docs(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + parent_doc = Document( + page_content="parent", + metadata={}, + children=[ + ChildDocument(page_content="child-1", metadata={}), + ChildDocument(page_content="child-2", metadata={}), + ], + ) + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls: + 
vector = mock_vector_cls.return_value + processor.load(dataset, [parent_doc], multimodal_documents=multimodal_docs) + + assert vector.create.call_count == 1 + formatted_docs = vector.create.call_args[0][0] + assert len(formatted_docs) == 2 + assert all(isinstance(doc, Document) for doc in formatted_docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + delete_query = Mock() + where_query = Mock() + where_query.delete.return_value = 2 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean( + dataset, + ["node-1"], + delete_child_chunks=True, + precomputed_child_node_ids=["child-1", "child-2"], + ) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_queries_child_ids_when_not_precomputed( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + child_query = Mock() + child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] + session = Mock() + session.query.return_value = child_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_child_chunks=False) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + + def test_clean_dataset_wide_cleanup(self, processor: 
ParentChildIndexProcessor, dataset: Mock) -> None: + where_query = Mock() + where_query.delete.return_value = 3 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_child_chunks=True) + + vector.delete.assert_called_once() + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.session_factory.create_session", + return_value=session_ctx, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, ["node-1"], delete_summaries=True, precomputed_child_node_ids=[]) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + + def test_clean_deletes_all_summaries_when_node_ids_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + 
processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + + def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8) + low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [ok_result, low_result] + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].page_content == "keep" + assert docs[0].metadata["score"] == 0.8 + + def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=None) + + with pytest.raises(ValueError, match="No subchunk segmentation found"): + processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None) + + def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=self._segmentation()) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".child-1", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + ): + child_docs = processor._split_child_nodes( + Document(page_content="parent", metadata={}), rules, "custom", None + ) + + assert len(child_docs) == 1 + assert child_docs[0].page_content == "child-1" + assert child_docs[0].metadata["doc_hash"] == "hash" + + def test_index_creates_process_rule_segments_and_vectors( + self, processor: 
ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[ + SimpleNamespace( + parent_content="parent text", + child_contents=["child-1", "child-2"], + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + ], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + side_effect=lambda text: f"hash-{text}", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + assert dataset_document.dataset_process_rule_id == "rule-1" + session.add.assert_called_once_with(dataset_rule) + session.flush.assert_called_once() + session.commit.assert_called_once() + mock_store_cls.return_value.add_documents.assert_called_once() + assert mock_vector_cls.return_value.create.call_count == 1 + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_uses_content_files_when_files_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + dataset_rule = 
SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + mock_files.assert_called_once() + + def test_index_raises_when_account_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, 
dataset_document, {"parent_child_chunks": []}) + + def test_format_preview_returns_parent_child_structure(self, processor: ParentChildIndexProcessor) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child-1", "child-2"])], + ) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ): + preview = processor.format_preview({"parent_child_chunks": []}) + + assert preview["chunk_structure"] == "hierarchical_model" + assert preview["parent_mode"] == ParentMode.PARAGRAPH + assert preview["total_segments"] == 1 + + def test_generate_summary_preview_sets_summaries(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ): + result = processor.generate_summary_preview( + "tenant-1", preview_texts, {"enable": True}, doc_language="English" + ) + + assert all(item.summary == "summary" for item in result) + + def test_generate_summary_preview_raises_when_worker_fails(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + side_effect=RuntimeError("summary failed"), + ): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + def test_generate_summary_preview_falls_back_without_flask_context( + self, processor: ParentChildIndexProcessor + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + fake_current_app = 
SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ), + ): + result = processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_handles_timeout( + self, processor: ParentChildIndexProcessor, fake_executor_cls: type + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + future.cancel.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py new file mode 100644 index 0000000000..8596647ef3 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -0,0 +1,382 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ImmediateThread: + def __init__(self, target, args=(), kwargs=None): + self._target = target + self._args = args + self._kwargs = kwargs or {} + + def start(self) -> None: + self._target(*self._args, **self._kwargs) + + def join(self) -> None: + 
return None + + +class TestQAIndexProcessor: + @pytest.fixture + def processor(self) -> QAIndexProcessor: + return QAIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def test_extract_forwards_automatic_flag(self, processor: QAIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.ExtractProcessor.extract") as mock_extract: + mock_extract.return_value = expected_docs + + docs = processor.extract(extract_setting, process_rule_mode="automatic") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_rejects_none_process_rule(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + def test_transform_rejects_missing_rules_key(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_preview_calls_formatter_once( + self, processor: 
QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + split_node = Document(page_content=".question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content="Q1", metadata={"answer": "A1"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", return_value="clean text" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform( + [document], + process_rule=process_rule, + preview=True, + tenant_id="tenant-1", + doc_language="English", + ) + + assert len(result) == 1 + assert result[0].metadata["answer"] == "A1" + mock_format.assert_called_once() + + def test_transform_non_preview_uses_thread_batches( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + documents = [ + Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), + Document(page_content="doc-2", metadata={"document_id": "doc-2", "dataset_id": "dataset-1"}), + ] + split_node = Document(page_content="question", 
metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content=f"Q-{document_node.page_content}", metadata={"answer": "A"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + patch( + "core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread + ), + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1") + + assert len(result) == 2 + assert mock_format.call_count == 2 + + def test_format_by_template_validates_file_type(self, processor: QAIndexProcessor) -> None: + not_csv_file = Mock(spec=FileStorage) + not_csv_file.filename = "qa.txt" + + with pytest.raises(ValueError, match="Only CSV files"): + processor.format_by_template(not_csv_file) + + def test_format_by_template_parses_csv_rows(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + dataframe = pd.DataFrame([["Q1", "A1"], ["Q2", "A2"]]) + + with 
patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=dataframe): + docs = processor.format_by_template(csv_file) + + assert [doc.page_content for doc in docs] == ["Q1", "Q2"] + assert [doc.metadata["answer"] for doc in docs] == ["A1", "A2"] + + def test_format_by_template_raises_on_empty_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=pd.DataFrame()): + with pytest.raises(ValueError, match="empty"): + processor.format_by_template(csv_file) + + def test_format_by_template_raises_on_invalid_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch( + "core.rag.index_processor.processor.qa_index_processor.pd.read_csv", side_effect=Exception("bad csv") + ): + with pytest.raises(ValueError, match="bad csv"): + processor.format_by_template(csv_file) + + def test_load_creates_vectors_for_high_quality_dataset(self, processor: QAIndexProcessor, dataset: Mock) -> None: + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + processor.load(dataset, docs) + + mock_vector_cls.assert_not_called() 
+ + def test_clean_handles_summary_deletion_and_vector_cleanup( + self, processor: QAIndexProcessor, dataset: Mock + ) -> None: + mock_segment = SimpleNamespace(id="seg-1") + mock_query = Mock() + mock_query.filter.return_value.all.return_value = [mock_segment] + mock_session = Mock() + mock_session.query.return_value = mock_query + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.session_factory.create_session", + return_value=session_context, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_handles_dataset_wide_cleanup(self, processor: QAIndexProcessor, dataset: Mock) -> None: + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + vector.delete.assert_called_once() + + def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None: + result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9) + result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1) + + with 
patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: + mock_retrieve.return_value = [result_ok, result_low] + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].page_content == "accepted" + assert docs[0].metadata["score"] == 0.9 + + def test_index_adds_documents_and_vectors_for_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + qa_chunks = SimpleNamespace( + qa_chunks=[ + SimpleNamespace(question="Q1", answer="A1"), + SimpleNamespace(question="Q2", answer="A2"), + ] + ) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore") as mock_store_cls, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + + def test_index_requires_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore"), + ): + with pytest.raises(ValueError, match="must be high quality"): + 
processor.index(dataset, dataset_document, {"qa_chunks": []}) + + def test_format_preview_returns_qa_preview(self, processor: QAIndexProcessor) -> None: + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ): + preview = processor.format_preview({"qa_chunks": []}) + + assert preview["chunk_structure"] == "qa_model" + assert preview["total_segments"] == 1 + assert preview["qa_preview"] == [{"question": "Q1", "answer": "A1"}] + + def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None: + preview_items = [PreviewDetail(content="Q1")] + assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items + + def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + blank_document = Document(page_content=" ", metadata={}) + + processor._format_qa_document(fake_flask_app, "tenant-1", blank_document, all_qa_documents, "English") + + assert all_qa_documents == [] + + def test_format_qa_document_builds_question_answer_documents( + self, processor: QAIndexProcessor, fake_flask_app + ) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + return_value="Q1: What is this?\nA1: A test.\nQ2: Why?\nA2: Coverage.", + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert len(all_qa_documents) == 2 + assert all_qa_documents[0].page_content == "What is this?" 
+ assert all_qa_documents[0].metadata["answer"] == "A test." + assert all_qa_documents[1].metadata["answer"] == "Coverage." + + def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + side_effect=RuntimeError("llm failure"), + ), + patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger, + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert all_qa_documents == [] + mock_logger.exception.assert_called_once_with("Failed to format qa document") + + def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None: + parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n") + + assert parsed == [{"question": "First?", "answer": "One."}, {"question": "Second?", "answer": "Two."}] diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py new file mode 100644 index 0000000000..b31bb6eea7 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -0,0 +1,291 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import httpx +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ForwardingBaseIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting, **kwargs): + return super().extract(extract_setting, **kwargs) + + def transform(self, 
documents, current_user=None, **kwargs): + return super().transform(documents, current_user=current_user, **kwargs) + + def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None): + return super().generate_summary_preview( + tenant_id=tenant_id, + preview_texts=preview_texts, + summary_index_setting=summary_index_setting, + doc_language=doc_language, + ) + + def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs): + return super().load( + dataset=dataset, + documents=documents, + multimodal_documents=multimodal_documents, + with_keywords=with_keywords, + **kwargs, + ) + + def clean(self, dataset, node_ids, with_keywords=True, **kwargs): + return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs) + + def index(self, dataset, document, chunks): + return super().index(dataset=dataset, document=document, chunks=chunks) + + def format_preview(self, chunks): + return super().format_preview(chunks) + + def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model): + return super().retrieve( + retrieval_method=retrieval_method, + query=query, + dataset=dataset, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + + +class TestBaseIndexProcessor: + @pytest.fixture + def processor(self) -> _ForwardingBaseIndexProcessor: + return _ForwardingBaseIndexProcessor() + + def test_abstract_methods_raise_not_implemented(self, processor: _ForwardingBaseIndexProcessor) -> None: + with pytest.raises(NotImplementedError): + processor.extract(Mock()) + with pytest.raises(NotImplementedError): + processor.transform([]) + with pytest.raises(NotImplementedError): + processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {}) + with pytest.raises(NotImplementedError): + processor.load(Mock(), []) + with pytest.raises(NotImplementedError): + processor.clean(Mock(), None) + with 
pytest.raises(NotImplementedError): + processor.index(Mock(), Mock(), {}) + with pytest.raises(NotImplementedError): + processor.format_preview([]) + with pytest.raises(NotImplementedError): + processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {}) + + def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None: + with patch( + "core.rag.index_processor.index_processor_base.dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", 1000 + ): + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 49, 0, "", None) + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 1001, 0, "", None) + + def test_get_splitter_custom_mode_uses_fixed_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + fixed_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.FixedRecursiveCharacterTextSplitter.from_encoder", + return_value=fixed_splitter, + ) as mock_fixed: + splitter = processor._get_splitter("hierarchical", 120, 10, "\\n\\n", None) + + assert splitter is fixed_splitter + assert mock_fixed.call_args.kwargs["fixed_separator"] == "\n\n" + assert mock_fixed.call_args.kwargs["chunk_size"] == 120 + + def test_get_splitter_automatic_mode_uses_enhance_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + auto_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.EnhanceRecursiveCharacterTextSplitter.from_encoder", + return_value=auto_splitter, + ) as mock_enhance: + splitter = processor._get_splitter("automatic", 0, 0, "", None) + + assert splitter is auto_splitter + assert "chunk_size" in mock_enhance.call_args.kwargs + + def test_extract_markdown_images(self, processor: _ForwardingBaseIndexProcessor) -> None: + markdown = "text ![a](https://a/img.png) and ![b](/files/123/file-preview)" + images = processor._extract_markdown_images(markdown) + assert images == 
["https://a/img.png", "/files/123/file-preview"] + + def test_get_content_files_without_images_returns_empty(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="no image markdown", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + assert processor._get_content_files(document) == [] + + def test_get_content_files_handles_all_sources_and_duplicates( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = [ + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/file-preview", + "/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", + "https://example.com/remote.png?x=1", + ] + upload_a = SimpleNamespace(id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", name="a.png") + upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") + upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") + upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") + db_query = Mock() + db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch.object(processor, "_download_tool_file", return_value="tool-upload-id") as mock_tool_download, + patch.object(processor, "_download_image", return_value="remote-upload-id") as mock_image_download, + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document, current_user=Mock()) + + assert len(files) == 5 + assert all(isinstance(file, AttachmentDocument) for file in files) + assert files[0].metadata["doc_type"] == DocType.IMAGE + assert 
files[0].metadata["document_id"] == "doc-1" + assert files[0].metadata["dataset_id"] == "ds-1" + assert files[0].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + assert files[1].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_tool_download.assert_called_once() + mock_image_download.assert_called_once() + + def test_get_content_files_skips_tool_and_remote_download_without_user( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", "https://example.com/remote.png"] + + with patch.object(processor, "_extract_markdown_images", return_value=images): + files = processor._get_content_files(document, current_user=None) + + assert files == [] + + def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] + db_query = Mock() + db_query.where.return_value.all.return_value = [] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document) + + assert files == [] + + def test_download_image_success_with_filename_from_content_disposition( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + response = Mock() + response.headers = { + "Content-Length": "4", + "content-disposition": "attachment; filename=test-image.png", + "content-type": "image/png", + } + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"data"] + upload_result = SimpleNamespace(id="upload-id") + + mock_db = 
Mock() + mock_db.engine = Mock() + + with ( + patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response), + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + upload_id = processor._download_image("https://example.com/test.png", current_user=Mock()) + + assert upload_id == "upload-id" + mock_file_service.return_value.upload_file.assert_called_once() + + def test_download_image_validates_size_and_empty_content(self, processor: _ForwardingBaseIndexProcessor) -> None: + too_large = Mock() + too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"} + too_large.raise_for_status.return_value = None + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large): + assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None + + empty = Mock() + empty.headers = {"Content-Length": "0", "content-type": "image/png"} + empty.raise_for_status.return_value = None + empty.iter_bytes.return_value = [] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty): + assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None + + def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None: + response = Mock() + response.headers = {"content-type": "image/png"} + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response): + assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None + + def test_download_image_handles_timeout_request_and_unexpected_errors( + self, processor: 
_ForwardingBaseIndexProcessor + ) -> None: + request = httpx.Request("GET", "https://example.com/image.png") + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.TimeoutException("timeout"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.RequestError("bad request", request=request), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=RuntimeError("unexpected"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + db_query = Mock() + db_query.where.return_value.first.return_value = None + db_session = Mock() + db_session.query.return_value = db_query + + with patch("core.rag.index_processor.index_processor_base.db.session", db_session): + assert processor._download_tool_file("tool-id", current_user=Mock()) is None + + def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") + db_query = Mock() + db_query.where.return_value.first.return_value = tool_file + db_session = Mock() + db_session.query.return_value = db_query + mock_db = Mock() + mock_db.session = db_session + mock_db.engine = Mock() + upload_result = SimpleNamespace(id="upload-id") + + with ( + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("core.rag.index_processor.index_processor_base.storage.load_once", return_value=b"blob") as mock_load, + patch("services.file_service.FileService") as mock_file_service, + ): + 
mock_file_service.return_value.upload_file.return_value = upload_result + result = processor._download_tool_file("tool-id", current_user=Mock()) + + assert result == "upload-id" + mock_load.assert_called_once_with("k1") + mock_file_service.return_value.upload_file.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py new file mode 100644 index 0000000000..0fc666dbbf --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py @@ -0,0 +1,42 @@ +import pytest + +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor + + +class TestIndexProcessorFactory: + def test_requires_index_type(self) -> None: + factory = IndexProcessorFactory(index_type=None) + + with pytest.raises(ValueError, match="Index type must be specified"): + factory.init_index_processor() + + def test_builds_paragraph_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARAGRAPH_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParagraphIndexProcessor) + + def test_builds_qa_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.QA_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, QAIndexProcessor) + + def test_builds_parent_child_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARENT_CHILD_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, 
ParentChildIndexProcessor) + + def test_rejects_unsupported_index_type(self) -> None: + factory = IndexProcessorFactory(index_type="unsupported") + + with pytest.raises(ValueError, match="is not supported"): + factory.init_index_processor() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 0e53482c51..b150d677f1 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -12,13 +12,18 @@ All tests use mocking to avoid external dependencies and ensure fast, reliable e Tests follow the Arrange-Act-Assert pattern for clarity. """ +from operator import itemgetter +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest from core.model_manager import ModelInstance +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode @@ -26,7 +31,7 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -def create_mock_model_instance(): +def create_mock_model_instance() -> ModelInstance: """Create a properly configured mock ModelInstance for reranking tests.""" mock_instance = Mock(spec=ModelInstance) # Setup provider_model_bundle chain for check_model_support_vision @@ -59,14 +64,7 @@ class TestRerankModelRunner: @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" - mock_instance = Mock(spec=ModelInstance) - # Setup 
provider_model_bundle chain for check_model_support_vision - mock_instance.provider_model_bundle = Mock() - mock_instance.provider_model_bundle.configuration = Mock() - mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" - mock_instance.provider = "test-provider" - mock_instance.model_name = "test-model" - return mock_instance + return create_mock_model_instance() @pytest.fixture def rerank_runner(self, mock_model_instance): @@ -382,6 +380,206 @@ class TestRerankModelRunner: assert call_kwargs["user"] == "user123" +class _ForwardingBaseRerankRunner(BaseRerankRunner): + def run( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> list[Document]: + return super().run( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=top_n, + user=user, + query_type=query_type, + ) + + +class TestBaseRerankRunner: + def test_run_raises_not_implemented(self): + runner = _ForwardingBaseRerankRunner() + + with pytest.raises(NotImplementedError): + runner.run(query="python", documents=[]) + + +class TestRerankModelRunnerMultimodal: + @pytest.fixture + def mock_model_instance(self): + return create_mock_model_instance() + + @pytest.fixture + def rerank_runner(self, mock_model_instance): + return RerankModelRunner(rerank_model_instance=mock_model_instance) + + def test_run_returns_original_documents_for_non_text_query_without_vision_support( + self, rerank_runner, mock_model_instance + ): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), + ] + + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) + + assert result == documents + 
mock_model_instance.invoke_rerank.assert_not_called() + + def test_run_uses_multimodal_path_when_vision_support_is_enabled(self, rerank_runner): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1", "source": "wiki"}, provider="dify"), + ] + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="doc", score=0.88)], + ) + + with ( + patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch.object( + rerank_runner, + "fetch_multimodal_rerank", + return_value=(rerank_result, documents), + ) as mock_multimodal, + ): + mock_mm.return_value.check_model_support_vision.return_value = True + result = rerank_runner.run(query="python", documents=documents, query_type=QueryType.TEXT_QUERY) + + mock_multimodal.assert_called_once() + assert len(result) == 1 + assert result[0].metadata["score"] == 0.88 + + def test_fetch_multimodal_rerank_builds_docs_and_calls_text_rerank(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-1", "doc_type": DocType.IMAGE}, + provider="dify", + ) + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + external_doc = Document( + page_content="external-content", + metadata={}, + provider="external", + ) + query = Mock() + query.where.return_value.first.return_value = SimpleNamespace(key="image-key") + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc, text_doc, external_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc, text_doc, external_doc, external_doc], 
+ query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert len(unique_documents) == 3 + mock_load_once.assert_called_once_with("image-key") + text_rerank_call_args = mock_text_rerank.call_args.args + assert len(text_rerank_call_args[1]) == 3 + + def test_fetch_multimodal_rerank_skips_missing_image_upload(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, + provider="dify", + ) + query = Mock() + query.where.return_value.first.return_value = None + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [image_doc] + docs_arg = mock_text_rerank.call_args.args[1] + assert len(docs_arg) == 1 + + def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance): + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + query_chain = Mock() + query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key") + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="text-content", score=0.77)], + ) + mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + 
query="query-upload-id", + documents=[text_doc], + score_threshold=0.2, + top_n=2, + user="user-1", + query_type=QueryType.IMAGE_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [text_doc] + invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs + assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE + assert invoke_kwargs["docs"][0]["content"] == "text-content" + assert invoke_kwargs["user"] == "user-1" + + def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): + query_chain = Mock() + query_chain.where.return_value.first.return_value = None + + with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain): + with pytest.raises(ValueError, match="Upload file not found for query"): + rerank_runner.fetch_multimodal_rerank( + query="missing-upload-id", + documents=[], + query_type=QueryType.IMAGE_QUERY, + ) + + def test_fetch_multimodal_rerank_rejects_unsupported_query_type(self, rerank_runner): + with pytest.raises(ValueError, match="is not supported"): + rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[], + query_type="unsupported_query_type", + ) + + class TestWeightRerankRunner: """Unit tests for WeightRerankRunner. 
@@ -512,34 +710,39 @@ class TestWeightRerankRunner: - TF-IDF scores are calculated correctly - Cosine similarity is computed for keyword vectors """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - - # Mock keyword extraction with specific keywords + keyword_map = { + "python programming": ["python", "programming"], + "Python is a programming language": ["python", "programming", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.side_effect = [ - ["python", "programming"], # query - ["python", "programming", "language"], # doc1 - ["javascript", "web"], # doc2 - ["java", "programming"], # doc3 - ] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("python programming", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "python programming", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="python programming", documents=sample_documents_with_vectors) - # Assert: Keywords are extracted and scores are calculated - assert len(result) == 3 - # Document 1 should have highest keyword score (matches both query terms) - # 
Document 3 should have medium score (matches one term) - # Document 2 should have lowest score (matches no terms) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_vector_score_calculation( self, @@ -556,30 +759,42 @@ class TestWeightRerankRunner: - Cosine similarity is calculated with document vectors - Vector scores are properly normalized """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test query": ["test"], + "Python is a programming language": ["python"], + "JavaScript for web development": ["javascript"], + "Java object-oriented programming": ["java"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding model mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance - # Mock cache embedding with specific query vector mock_cache_instance = MagicMock() query_vector = [0.2, 0.3, 0.4, 0.5] mock_cache_instance.embed_query.return_value = query_vector mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test query", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test query", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in 
zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test query", documents=sample_documents_with_vectors) - # Assert: Vector scores are calculated - assert len(result) == 3 - # Verify cosine similarity was computed (doc2 vector is closest to query vector) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_score_threshold_filtering_weighted( self, @@ -742,28 +957,40 @@ class TestWeightRerankRunner: - Keyword weight (0.4) is applied to keyword scores - Combined score is the sum of weighted components """ - # Arrange: Create runner with known weights runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Python is a programming language": ["python", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test", sample_documents_with_vectors, weights_config.vector_setting + ) + 
expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test", documents=sample_documents_with_vectors) - # Assert: Scores are combined with weights - # Score = 0.6 * vector_score + 0.4 * keyword_score - assert len(result) == 3 - assert all("score" in doc.metadata for doc in result) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_existing_vector_score_in_metadata( self, @@ -778,7 +1005,6 @@ class TestWeightRerankRunner: - If document already has a score in metadata, it's used - Cosine similarity calculation is skipped for such documents """ - # Arrange: Documents with pre-existing scores documents = [ Document( page_content="Content with existing score", @@ -790,24 +1016,29 @@ class TestWeightRerankRunner: runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Content with existing score": ["test"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", documents) + vector_scores 
= runner._calculate_cosine("tenant123", "test", documents, weights_config.vector_setting) + expected_score = 0.6 * vector_scores[0] + 0.4 * query_scores[0] + result = runner.run(query="test", documents=documents) - # Assert: Existing score is used in calculation assert len(result) == 1 - # The final score should incorporate the existing score (0.95) with vector weight (0.6) + assert result[0].metadata["doc_id"] == "doc1" + assert result[0].metadata["score"] == pytest.approx(expected_score, rel=1e-6) class TestRerankRunnerFactory: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index ca08cb0591..b90c4935af 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,80 +1,41 @@ -""" -Unit tests for dataset retrieval functionality. - -This module provides comprehensive test coverage for the RetrievalService class, -which is responsible for retrieving relevant documents from datasets using various -search strategies. - -Core Retrieval Mechanisms Tested: -================================== -1. **Vector Search (Semantic Search)** - - Uses embedding vectors to find semantically similar documents - - Supports score thresholds and top-k limiting - - Can filter by document IDs and metadata - -2. **Keyword Search** - - Traditional text-based search using keyword matching - - Handles special characters and query escaping - - Supports document filtering - -3. **Full-Text Search** - - BM25-based full-text search for text matching - - Used in hybrid search scenarios - -4. **Hybrid Search** - - Combines vector and full-text search results - - Implements deduplication to avoid duplicate chunks - - Uses DataPostProcessor for score merging with configurable weights - -5. 
**Score Merging Algorithms** - - Deduplication based on doc_id - - Retains higher-scoring duplicates - - Supports weighted score combination - -6. **Metadata Filtering** - - Filters documents based on metadata conditions - - Supports document ID filtering - -Test Architecture: -================== -- **Fixtures**: Provide reusable mock objects (datasets, documents, Flask app) -- **Mocking Strategy**: Mock at the method level (embedding_search, keyword_search, etc.) - rather than at the class level to properly simulate the ThreadPoolExecutor behavior -- **Pattern**: All tests follow Arrange-Act-Assert (AAA) pattern -- **Isolation**: Each test is independent and doesn't rely on external state - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py -v - - # Run a specific test class - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::TestRetrievalService -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::\ -TestRetrievalService::test_vector_search_basic -v - -Notes: -====== -- The RetrievalService uses ThreadPoolExecutor for concurrent search operations -- Tests mock the individual search methods to avoid threading complexity -- All mocked search methods modify the all_documents list in-place -- Score thresholds and top-k limits are enforced by the search methods -""" - +import threading +from contextlib import contextmanager, nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest +from flask import Flask, current_app +from sqlalchemy import column +from core.app.app_config.entities import ( + Condition as AppCondition, +) +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, +) +from core.app.app_config.entities import 
( + MetadataFilteringCondition as AppMetadataFilteringCondition, +) +from core.app.app_config.entities import ( + ModelConfig as AppModelConfig, +) +from core.app.app_config.entities import ModelConfig as WorkflowModelConfig +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.agent_entities import PlanningStrategy +from core.entities.model_entities import ModelStatus from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.nodes.knowledge_retrieval import exc +from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== @@ -2013,3 +1974,3091 @@ class TestDocumentModel: assert doc1 == doc2 assert doc1 != doc3 + + +# ==================== Helper Functions ==================== + + +def create_mock_dataset_methods( + dataset_id: str | None = None, + tenant_id: str | None = None, + provider: str = "dify", + indexing_technique: str = "high_quality", + available_document_count: int = 10, +) -> Mock: + """ + Create a mock Dataset object for testing. 
+ + Args: + dataset_id: Unique identifier for the dataset + tenant_id: Tenant ID for the dataset + provider: Provider type ("dify" or "external") + indexing_technique: Indexing technique ("high_quality" or "economy") + available_document_count: Number of available documents + + Returns: + Mock: A properly configured Dataset mock + """ + dataset = Mock(spec=Dataset) + dataset.id = dataset_id or str(uuid4()) + dataset.tenant_id = tenant_id or str(uuid4()) + dataset.name = "test_dataset" + dataset.provider = provider + dataset.indexing_technique = indexing_technique + dataset.available_document_count = available_document_count + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "top_k": 4, + "score_threshold_enabled": False, + } + return dataset + + +def create_mock_document_methods( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +# ==================== Test _check_knowledge_rate_limit ==================== + + +class TestCheckKnowledgeRateLimit: + """ + Test suite for _check_knowledge_rate_limit method. 
+ + The _check_knowledge_rate_limit method validates whether a tenant has + exceeded their knowledge retrieval rate limit. This is important for: + - Preventing abuse of the knowledge retrieval system + - Enforcing subscription plan limits + - Tracking usage for billing purposes + + Test Cases: + ============ + 1. Rate limit disabled - no exception raised + 2. Rate limit enabled but not exceeded - no exception raised + 3. Rate limit enabled and exceeded - RateLimitExceededError raised + 4. Redis operations are performed correctly + 5. RateLimitLog is created when limit is exceeded + """ + + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service): + """ + Test that when rate limit is disabled, no exception is raised. + + This test verifies the behavior when the tenant's subscription + does not have rate limiting enabled. + + Verifies: + - FeatureService.get_knowledge_rate_limit is called + - No Redis operations are performed + - No exception is raised + - Retrieval proceeds normally + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit disabled + mock_limit = Mock() + mock_limit.enabled = False + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify FeatureService was called + mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id) + + # Verify no Redis operations were performed + assert not mock_redis.zadd.called + assert not mock_redis.zremrangebyscore.called + assert not mock_redis.zcard.called + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + 
@patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory): + """ + Test that when rate limit is enabled but not exceeded, no exception is raised. + + This test simulates a tenant making requests within their rate limit. + The Redis sorted set stores timestamps of recent requests, and old + requests (older than 60 seconds) are removed. + + Verifies: + - Redis zadd is called to track the request + - Redis zremrangebyscore removes old entries + - Redis zcard returns count within limit + - No exception is raised + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 # Current time in milliseconds + mock_time.time.return_value = current_time / 1000 # Return seconds + mock_time.time.__mul__ = lambda self, x: int(self * x) # Multiply to get milliseconds + + # Mock Redis operations + # zcard returns 50 (within limit of 100) + mock_redis.zcard.return_value = 50 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify Redis operations + expected_key = f"rate_limit_{tenant_id}" + mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time}) + mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000) + mock_redis.zcard.assert_called_once_with(expected_key) + + 
@patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_exceeded_raises_exception( + self, mock_time, mock_redis, mock_feature_service, mock_session_factory + ): + """ + Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised. + + This test simulates a tenant exceeding their rate limit. When the count + of recent requests exceeds the limit, an exception should be raised and + a RateLimitLog should be created. + + Verifies: + - Redis zcard returns count exceeding limit + - RateLimitExceededError is raised with correct message + - RateLimitLog is created in database + - Session operations are performed correctly + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 + mock_time.time.return_value = current_time / 1000 + + # Mock Redis operations - return count exceeding limit + mock_redis.zcard.return_value = 150 # Exceeds limit of 100 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert + with pytest.raises(exc.RateLimitExceededError) as exc_info: + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify exception message + assert "knowledge base request rate limit" in str(exc_info.value) + + # Verify RateLimitLog was created + mock_session.add.assert_called_once() + 
added_log = mock_session.add.call_args[0][0] + assert added_log.tenant_id == tenant_id + assert added_log.subscription_plan == "professional" + assert added_log.operation == "knowledge" + + +# ==================== Test _get_available_datasets ==================== + + +class TestGetAvailableDatasets: + """ + Test suite for _get_available_datasets method. + + The _get_available_datasets method retrieves datasets that are available + for retrieval. A dataset is considered available if: + - It belongs to the specified tenant + - It's in the list of requested dataset_ids + - It has at least one completed, enabled, non-archived document OR + - It's an external provider dataset + + Note: Due to SQLAlchemy subquery complexity, full testing is done in + integration tests. Unit tests here verify basic behavior. + """ + + def test_method_exists_and_has_correct_signature(self): + """ + Test that the method exists and has the correct signature. + + Verifies: + - Method exists on DatasetRetrieval class + - Accepts tenant_id and dataset_ids parameters + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + + # Assert - method exists + assert hasattr(dataset_retrieval, "_get_available_datasets") + # Assert - method is callable + assert callable(dataset_retrieval._get_available_datasets) + + +# ==================== Test knowledge_retrieval ==================== + + +class TestDatasetRetrievalKnowledgeRetrieval: + """ + Test suite for knowledge_retrieval method. + + The knowledge_retrieval method is the main entry point for retrieving + knowledge from datasets. It orchestrates the entire retrieval process: + 1. Checks rate limits + 2. Gets available datasets + 3. Applies metadata filtering if enabled + 4. Performs retrieval (single or multiple mode) + 5. Formats and returns results + + Test Cases: + ============ + 1. Single mode retrieval + 2. Multiple mode retrieval + 3. Metadata filtering disabled + 4. Metadata filtering automatic + 5. Metadata filtering manual + 6. 
External documents handling + 7. Dify documents handling + 8. Empty results handling + 9. Rate limit exceeded + 10. No available datasets + """ + + def test_knowledge_retrieval_single_mode_basic(self): + """ + Test knowledge_retrieval in single retrieval mode - basic check. + + Note: Full single mode testing requires complex model mocking and + is better suited for integration tests. This test verifies the + method accepts single mode requests. + + Verifies: + - Method can accept single mode request + - Request parameters are correctly structured + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + model_mode="chat", + completion_params={"temperature": 0.7}, + ) + + # Assert - request is properly structured + assert request.retrieval_mode == "single" + assert request.model_provider == "openai" + assert request.model_name == "gpt-4" + assert request.model_mode == "chat" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor): + """ + Test knowledge_retrieval in multiple retrieval mode. + + In multiple mode, retrieval is performed across all datasets and + results are combined and reranked. 
+ + Verifies: + - Rate limit is checked + - Available datasets are retrieved + - Multiple retrieval is performed + - Results are combined and reranked + - Results are formatted correctly + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id1 = str(uuid4()) + dataset_id2 = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id1, dataset_id2], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + score_threshold=0.7, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets + mock_dataset1 = create_mock_dataset_methods(dataset_id=dataset_id1, tenant_id=tenant_id) + mock_dataset2 = create_mock_dataset_methods(dataset_id=dataset_id2, tenant_id=tenant_id) + with patch.object( + dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2] + ): + # Mock get_metadata_filter_condition + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return documents + doc1 = create_mock_document_methods("Python is great", "doc1", score=0.9) + doc2 = create_mock_document_methods("Python is awesome", "doc2", score=0.8) + with patch.object( + dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2] + ) as mock_multiple_retrieve: + # Mock format_retrieval_documents + mock_record = Mock() + mock_record.segment = Mock() + mock_record.segment.dataset_id = dataset_id1 + mock_record.segment.document_id = str(uuid4()) + mock_record.segment.index_node_hash = "hash123" + mock_record.segment.hit_count = 5 + mock_record.segment.word_count 
= 100 + mock_record.segment.position = 1 + mock_record.segment.get_sign_content.return_value = "Python is great" + mock_record.segment.answer = None + mock_record.score = 0.9 + mock_record.child_chunks = [] + mock_record.summary = None + mock_record.files = None + + mock_retrieval_service = Mock() + mock_retrieval_service.format_retrieval_documents.return_value = [mock_record] + + with patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService", + return_value=mock_retrieval_service, + ): + # Mock database queries + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + mock_dataset_from_db = Mock() + mock_dataset_from_db.id = dataset_id1 + mock_dataset_from_db.name = "test_dataset" + + mock_document = Mock() + mock_document.id = str(uuid4()) + mock_document.name = "test_doc" + mock_document.data_source_type = "upload_file" + mock_document.doc_metadata = {} + + mock_session.query.return_value.filter.return_value.all.return_value = [ + mock_dataset_from_db + ] + mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( + [mock_dataset_from_db, mock_document] + ) + + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + mock_multiple_retrieve.assert_called_once() + + def test_knowledge_retrieval_metadata_filtering_disabled(self): + """ + Test knowledge_retrieval with metadata filtering disabled. + + When metadata filtering is disabled, get_metadata_filter_condition is + NOT called (the method checks metadata_filtering_mode != "disabled"). 
+ + Verifies: + - get_metadata_filter_condition is NOT called when mode is "disabled" + - Retrieval proceeds without metadata filters + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + metadata_filtering_mode="disabled", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + # Mock get_metadata_filter_condition - should NOT be called when disabled + with patch.object( + dataset_retrieval, + "get_metadata_filter_condition", + return_value=(None, None), + ) as mock_get_metadata: + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + # get_metadata_filter_condition should NOT be called when mode is "disabled" + mock_get_metadata.assert_not_called() + + def test_knowledge_retrieval_with_external_documents(self): + """ + Test knowledge_retrieval with external documents. + + External documents come from external knowledge bases and should + be formatted differently than Dify documents. 
+ + Verifies: + - External documents are handled correctly + - Provider is set to "external" + - Metadata includes external-specific fields + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id, provider="external") + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Create external document + external_doc = create_mock_document_methods( + "External knowledge", + "doc1", + score=0.9, + provider="external", + additional_metadata={ + "dataset_id": dataset_id, + "dataset_name": "external_kb", + "document_id": "ext_doc1", + "title": "External Document", + }, + ) + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + if result: + assert result[0].metadata.data_source_type == "external" + + def test_knowledge_retrieval_empty_results(self): + """ + Test knowledge_retrieval when no documents are found. 
+ + Verifies: + - Empty list is returned + - No errors are raised + - All dependencies are still called + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return empty list + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_rate_limit_exceeded(self): + """ + Test knowledge_retrieval when rate limit is exceeded. 
+ + Verifies: + - RateLimitExceededError is raised + - No further processing occurs + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit to raise exception + with patch.object( + dataset_retrieval, + "_check_knowledge_rate_limit", + side_effect=exc.RateLimitExceededError("Rate limit exceeded"), + ): + # Act & Assert + with pytest.raises(exc.RateLimitExceededError): + dataset_retrieval.knowledge_retrieval(request) + + def test_knowledge_retrieval_no_available_datasets(self): + """ + Test knowledge_retrieval when no datasets are available. + + Verifies: + - Empty list is returned + - No retrieval is attempted + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets to return empty list + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self): + """ + Test that knowledge_retrieval processes multiple documents with different scores. 
+ + Note: Full sorting and position testing requires complex SQLAlchemy mocking + which is better suited for integration tests. This test verifies documents + with different scores can be created and have their metadata. + + Verifies: + - Documents can be created with different scores + - Score metadata is properly set + """ + # Create documents with different scores + doc1 = create_mock_document_methods("Low score", "doc1", score=0.6) + doc2 = create_mock_document_methods("High score", "doc2", score=0.95) + doc3 = create_mock_document_methods("Medium score", "doc3", score=0.8) + + # Assert - each document has the correct score + assert doc1.metadata["score"] == 0.6 + assert doc2.metadata["score"] == 0.95 + assert doc3.metadata["score"] == 0.8 + + # Assert - documents are correctly sorted (not the retrieval result, just the list) + unsorted = [doc1, doc2, doc3] + sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True) + assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6] + + +class TestProcessMetadataFilterFunc: + """ + Comprehensive test suite for process_metadata_filter_func method. + + This test class validates all metadata filtering conditions supported by + the DatasetRetrieval class, including string operations, numeric comparisons, + null checks, and list operations. + + Method Signature: + ================== + def process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + ) -> list: + + The method builds SQLAlchemy filter expressions by: + 1. Validating value is not None (except for empty/not empty conditions) + 2. Using DatasetDocument.doc_metadata JSON field operations + 3. Adding appropriate SQLAlchemy expressions to the filters list + 4. 
Returning the updated filters list + + Mocking Strategy: + ================== + - Mock DatasetDocument.doc_metadata to avoid database dependencies + - Verify filter expressions are created correctly + - Test with various data types (str, int, float, list) + """ + + @pytest.fixture + def retrieval(self): + """ + Create a DatasetRetrieval instance for testing. + + Returns: + DatasetRetrieval: Instance to test process_metadata_filter_func + """ + return DatasetRetrieval() + + @pytest.fixture + def mock_doc_metadata(self): + """ + Mock the DatasetDocument.doc_metadata JSON field. + + The method uses DatasetDocument.doc_metadata[metadata_name] to access + JSON fields. We mock this to avoid database dependencies. + + Returns: + Mock: Mocked doc_metadata attribute + """ + mock_metadata_field = MagicMock() + + # Create mock for string access + mock_string_access = MagicMock() + mock_string_access.like = MagicMock() + mock_string_access.notlike = MagicMock() + mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_string_access.in_ = MagicMock(return_value=MagicMock()) + + # Create mock for float access (for numeric comparisons) + mock_float_access = MagicMock() + mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__le__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) + + # Create mock for null checks + mock_null_access = MagicMock() + mock_null_access.is_ = MagicMock(return_value=MagicMock()) + mock_null_access.isnot = MagicMock(return_value=MagicMock()) + + # Setup __getitem__ to return appropriate mock based on usage + def getitem_side_effect(name): + if name in ["author", "title", "category"]: + return 
mock_string_access + elif name in ["year", "price", "rating"]: + return mock_float_access + else: + return mock_string_access + + mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) + mock_metadata_field.as_string.return_value = mock_string_access + mock_metadata_field.as_float.return_value = mock_float_access + mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ + mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot + + return mock_metadata_field + + # ==================== String Condition Tests ==================== + + def test_contains_condition_string_value(self, retrieval): + """ + Test 'contains' condition with string value. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value% syntax + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "John" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_contains_condition(self, retrieval): + """ + Test 'not contains' condition. + + Verifies: + - Filters list is populated with NOT LIKE expression + - Pattern matching uses %value% syntax with negation + """ + filters = [] + sequence = 0 + condition = "not contains" + metadata_name = "title" + value = "banned" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_start_with_condition(self, retrieval): + """ + Test 'start with' condition. 
+ + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses value% syntax + """ + filters = [] + sequence = 0 + condition = "start with" + metadata_name = "category" + value = "tech" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_end_with_condition(self, retrieval): + """ + Test 'end with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value syntax + """ + filters = [] + sequence = 0 + condition = "end with" + metadata_name = "filename" + value = ".pdf" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Equality Condition Tests ==================== + + def test_is_condition_with_string_value(self, retrieval): + """ + Test 'is' (=) condition with string value. + + Verifies: + - Filters list is populated with equality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = "Jane Doe" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_equals_condition_with_string_value(self, retrieval): + """ + Test '=' condition with string value. + + Verifies: + - Same behavior as 'is' condition + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "=" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_int_value(self, retrieval): + """ + Test 'is' condition with integer value. 
+ + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_float_value(self, retrieval): + """ + Test 'is' condition with float value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "price" + value = 19.99 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_string_value(self, retrieval): + """ + Test 'is not' (≠) condition with string value. + + Verifies: + - Filters list is populated with inequality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "author" + value = "Unknown" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_equals_condition(self, retrieval): + """ + Test '≠' condition with string value. + + Verifies: + - Same behavior as 'is not' condition + - Inequality expression is used + """ + filters = [] + sequence = 0 + condition = "≠" + metadata_name = "category" + value = "archived" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_numeric_value(self, retrieval): + """ + Test 'is not' condition with numeric value. 
+ + Verifies: + - Numeric inequality comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Null Condition Tests ==================== + + def test_empty_condition(self, retrieval): + """ + Test 'empty' condition (null check). + + Verifies: + - Filters list is populated with IS NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "empty" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_empty_condition(self, retrieval): + """ + Test 'not empty' condition (not null check). + + Verifies: + - Filters list is populated with IS NOT NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "not empty" + metadata_name = "description" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Numeric Comparison Tests ==================== + + def test_before_condition(self, retrieval): + """ + Test 'before' (<) condition. + + Verifies: + - Filters list is populated with less than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "before" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_condition(self, retrieval): + """ + Test '<' condition. 
+ + Verifies: + - Same behavior as 'before' condition + - Less than expression is used + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "price" + value = 100.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_after_condition(self, retrieval): + """ + Test 'after' (>) condition. + + Verifies: + - Filters list is populated with greater than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "after" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_condition(self, retrieval): + """ + Test '>' condition. + + Verifies: + - Same behavior as 'after' condition + - Greater than expression is used + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≤' condition. + + Verifies: + - Filters list is populated with less than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≤" + metadata_name = "price" + value = 50.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_ascii(self, retrieval): + """ + Test '<=' condition. 
+ + Verifies: + - Same behavior as '≤' condition + - Less than or equal expression is used + """ + filters = [] + sequence = 0 + condition = "<=" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≥' condition. + + Verifies: + - Filters list is populated with greater than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≥" + metadata_name = "rating" + value = 3.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_ascii(self, retrieval): + """ + Test '>=' condition. + + Verifies: + - Same behavior as '≥' condition + - Greater than or equal expression is used + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== List/In Condition Tests ==================== + + def test_in_condition_with_comma_separated_string(self, retrieval): + """ + Test 'in' condition with comma-separated string value. + + Verifies: + - String is split into list + - Whitespace is trimmed from each value + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "tech, science, AI " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_list_value(self, retrieval): + """ + Test 'in' condition with list value. 
+ + Verifies: + - List is processed correctly + - None values are filtered out + - IN expression is created with valid values + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "tags" + value = ["python", "javascript", None, "golang"] + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_tuple_value(self, retrieval): + """ + Test 'in' condition with tuple value. + + Verifies: + - Tuple is processed like a list + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = ("tech", "science", "ai") + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_empty_string(self, retrieval): + """ + Test 'in' condition with empty string value. + + Verifies: + - Empty string results in literal(False) filter + - No valid values to match + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + # Verify it's a literal(False) expression + # This is a bit tricky to test without access to the actual expression + + def test_in_condition_with_only_whitespace(self, retrieval): + """ + Test 'in' condition with whitespace-only string value. 
+ + Verifies: + - Whitespace-only string results in literal(False) filter + - All values are stripped and filtered out + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = " , , " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_single_string(self, retrieval): + """ + Test 'in' condition with single non-comma string. + + Verifies: + - Single string is treated as single-item list + - IN expression is created with one value + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Edge Case Tests ==================== + + def test_none_value_with_non_empty_condition(self, retrieval): + """ + Test None value with conditions that require value. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values (except empty/not empty) + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 # No filter added + + def test_none_value_with_equals_condition(self, retrieval): + """ + Test None value with 'is' (=) condition. 
+ + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_none_value_with_numeric_condition(self, retrieval): + """ + Test None value with numeric comparison condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "year" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_existing_filters_preserved(self, retrieval): + """ + Test that existing filters are preserved. + + Verifies: + - Existing filters in the list are not removed + - New filters are appended to the list + """ + existing_filter = MagicMock() + filters = [existing_filter] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 2 + assert filters[0] == existing_filter + + def test_multiple_filters_accumulated(self, retrieval): + """ + Test multiple calls to accumulate filters. 
+ + Verifies: + - Each call adds a new filter to the list + - All filters are preserved across calls + """ + filters = [] + + # First filter + retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) + assert len(filters) == 1 + + # Second filter + retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) + assert len(filters) == 2 + + # Third filter + retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) + assert len(filters) == 3 + + def test_unknown_condition(self, retrieval): + """ + Test unknown/unsupported condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for unknown conditions + """ + filters = [] + sequence = 0 + condition = "unknown_condition" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_empty_string_value_with_contains(self, retrieval): + """ + Test empty string value with 'contains' condition. + + Verifies: + - Filter is added even with empty string + - LIKE expression is created + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_special_characters_in_value(self, retrieval): + """ + Test special characters in value string. 
+ + Verifies: + - Special characters are handled in value + - LIKE expression is created correctly + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "title" + value = "C++ & Python's features" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_zero_value_with_numeric_condition(self, retrieval): + """ + Test zero value with numeric comparison condition. + + Verifies: + - Zero is treated as valid value + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "price" + value = 0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_negative_value_with_numeric_condition(self, retrieval): + """ + Test negative value with numeric comparison condition. + + Verifies: + - Negative numbers are handled correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "temperature" + value = -10.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_float_value_with_integer_comparison(self, retrieval): + """ + Test float value with numeric comparison condition. 
+ + Verifies: + - Float values work correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + +class TestKnowledgeRetrievalRegression: + @pytest.fixture + def mock_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = str(uuid4()) + dataset.tenant_id = str(uuid4()) + dataset.name = "test_dataset" + dataset.indexing_technique = "high_quality" + dataset.provider = "dify" + return dataset + + def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): + """ + Repro test for current bug: + reranking runs after `with flask_app.app_context():` exits. + `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, + so we must assert from that list (not from an outer try/except). + """ + dataset_retrieval = DatasetRetrieval() + flask_app = Flask(__name__) + tenant_id = str(uuid4()) + + # second dataset to ensure dataset_count > 1 reranking branch + secondary_dataset = Mock(spec=Dataset) + secondary_dataset.id = str(uuid4()) + secondary_dataset.provider = "dify" + secondary_dataset.indexing_technique = "high_quality" + + # retriever returns 1 doc into internal list (all_documents_item) + document = Document( + page_content="Context aware doc", + metadata={ + "doc_id": "doc1", + "score": 0.95, + "document_id": str(uuid4()), + "dataset_id": mock_dataset.id, + }, + provider="dify", + ) + + def fake_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.append(document) + + called = {"init": 0, "invoke": 0} + + class ContextRequiredPostProcessor: + def __init__(self, *args, **kwargs): + called["init"] += 1 + # will raise RuntimeError if no Flask app context exists + _ = current_app.name + 
+ def invoke(self, *args, **kwargs): + called["invoke"] += 1 + _ = current_app.name + return kwargs.get("documents") or args[1] + + # output list from _multiple_retrieve_thread + all_documents: list[Document] = [] + + # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here + thread_exceptions: list[Exception] = [] + + def target(): + with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): + with patch( + "core.rag.retrieval.dataset_retrieval.DataPostProcessor", + ContextRequiredPostProcessor, + ): + dataset_retrieval._multiple_retrieve_thread( + flask_app=flask_app, + available_datasets=[mock_dataset, secondary_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={ + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-v2", + }, + weights=None, + top_k=3, + score_threshold=0.0, + query="test query", + attachment_id=None, + dataset_count=2, # force reranking branch + thread_exceptions=thread_exceptions, # ✅ key + ) + + t = threading.Thread(target=target) + t.start() + t.join() + + # Ensure reranking branch was actually executed + assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." 
+ + # Current buggy code should record an exception (not raise it) + assert not thread_exceptions, thread_exceptions + + +class _FakeFlaskApp: + def app_context(self): + return nullcontext() + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class TestDatasetRetrievalAdditionalHelpers: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_llm_usage_and_record_usage(self, retrieval: DatasetRetrieval) -> None: + empty_usage = retrieval.llm_usage + assert empty_usage.total_tokens == 0 + + retrieval._record_usage(None) + assert retrieval.llm_usage.total_tokens == 0 + + usage_1 = LLMUsage.from_metadata({"prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5}) + usage_2 = LLMUsage.from_metadata({"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}) + retrieval._record_usage(usage_1) + retrieval._record_usage(usage_2) + assert retrieval.llm_usage.total_tokens == 10 + + def test_replace_metadata_filter_value(self, retrieval: DatasetRetrieval) -> None: + assert retrieval._replace_metadata_filter_value("plain", {}) == "plain" + replaced = retrieval._replace_metadata_filter_value( + "hello {{name}}\n\t{{missing}}", + {"name": "world"}, + ) + assert replaced == "hello world {{missing}}" + + def test_process_metadata_filter_in_with_scalar_fallback(self) -> None: + filters: list = [] + result = DatasetRetrieval.process_metadata_filter_func( + sequence=0, + condition="in", + metadata_name="category", + value=123, + filters=filters, + ) + assert result is filters + assert len(filters) == 1 + + def test_calculate_vector_score(self, retrieval: DatasetRetrieval) -> None: + doc_high = 
Document(page_content="a", metadata={"score": 0.9}, provider="dify") + doc_low = Document(page_content="b", metadata={"score": 0.2}, provider="dify") + doc_no_meta = Document(page_content="c", metadata={}, provider="dify") + + filtered = retrieval.calculate_vector_score([doc_low, doc_high, doc_no_meta], top_k=1, score_threshold=0.5) + assert len(filtered) == 1 + assert filtered[0].metadata["score"] == 0.9 + + assert retrieval.calculate_vector_score([doc_low], top_k=2, score_threshold=1.0) == [] + + def test_calculate_keyword_score(self, retrieval: DatasetRetrieval) -> None: + documents = [ + Document(page_content="python language", metadata={"doc_id": "1"}, provider="dify"), + Document(page_content="java language", metadata={"doc_id": "2"}, provider="dify"), + ] + keyword_handler = Mock() + keyword_handler.extract_keywords.side_effect = [ + ["python", "language"], + ["python", "language"], + ["java", "language"], + ] + + with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler): + ranked = retrieval.calculate_keyword_score("python language", documents, top_k=1) + + assert len(ranked) == 1 + assert "keywords" in ranked[0].metadata + assert ranked[0].metadata["doc_id"] == "1" + + def test_send_trace_task(self, retrieval: DatasetRetrieval) -> None: + trace_manager = Mock() + retrieval.application_generate_entity = SimpleNamespace(trace_manager=trace_manager) + docs = [Document(page_content="d", metadata={}, provider="dify")] + + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_called_once() + + retrieval.application_generate_entity = None + trace_manager.reset_mock() + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_not_called() + + def test_on_query(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query=None, + attachment_ids=None, + 
dataset_ids=["d1"], + app_id="a1", + user_from="web", + user_id="u1", + ) + mock_session.add_all.assert_not_called() + + retrieval._on_query( + query="python", + attachment_ids=["f1"], + dataset_ids=["d1", "d2"], + app_id="a1", + user_from="web", + user_id="u1", + ) + mock_session.add_all.assert_called() + mock_session.commit.assert_called() + + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: + usage = LLMUsage.empty_usage() + chunk_1 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace(message=SimpleNamespace(content="hello "), usage=usage), + ) + chunk_2 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace( + message=SimpleNamespace(content=[SimpleNamespace(data="world")]), + usage=None, + ), + ) + text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2])) + assert text == "hello world" + assert returned_usage == usage + + text_empty, usage_empty = retrieval._handle_invoke_result(iter([])) + assert text_empty == "" + assert usage_empty == LLMUsage.empty_usage() + + def test_get_prompt_template(self, retrieval: DatasetRetrieval) -> None: + model_config_chat = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=["x"], + ) + model_config_completion = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="completion", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + + with patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform") as mock_prompt_transform: + mock_prompt_transform.return_value.get_prompt.return_value = ["prompt"] + prompt_messages, stop = retrieval._get_prompt_template( + model_config=model_config_chat, + mode="chat", + metadata_fields=["author"], + query="python", + ) + assert 
prompt_messages == ["prompt"] + assert stop == ["x"] + + with patch( + "core.rag.retrieval.dataset_retrieval.METADATA_FILTER_COMPLETION_PROMPT", + "{input_text} {metadata_fields}", + ): + prompt_messages_completion, stop_completion = retrieval._get_prompt_template( + model_config=model_config_completion, + mode="completion", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages_completion == ["prompt"] + assert stop_completion == [] + + with pytest.raises(ValueError): + retrieval._get_prompt_template( + model_config=model_config_chat, + mode="unknown-mode", + metadata_fields=[], + query="python", + ) + + def test_fetch_model_config_validation_and_success(self, retrieval: DatasetRetrieval) -> None: + with pytest.raises(ValueError, match="single_retrieval_config is required"): + retrieval._fetch_model_config("tenant-1", None) # type: ignore[arg-type] + + model_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat", completion_params={"stop": ["END"]}) + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance = Mock() + model_instance.model_type_instance.get_model_schema.return_value = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance + mock_cfg_entity.return_value = SimpleNamespace( + provider="openai", + model="gpt", + stop=["END"], + parameters={"temperature": 0.1}, + ) + + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model = SimpleNamespace(status=ModelStatus.NO_CONFIGURE) + 
model_instance.provider_model_bundle.configuration.get_provider_model.return_value = provider_model + with pytest.raises(ValueError, match="credentials is not initialized"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.NO_PERMISSION + with pytest.raises(ValueError, match="currently not support"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.QUOTA_EXCEEDED + with pytest.raises(ValueError, match="quota exceeded"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.ACTIVE + bad_mode_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat") + bad_mode_cfg.mode = None # type: ignore[assignment] + with pytest.raises(ValueError, match="LLM mode is required"): + retrieval._fetch_model_config("tenant-1", bad_mode_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = Mock() + model_cfg_success = AppModelConfig( + provider="openai", + name="gpt", + mode="chat", + completion_params={"temperature": 0.1, "stop": ["END"]}, + ) + _, config = retrieval._fetch_model_config("tenant-1", model_cfg_success) + assert config.provider == "openai" + assert config.model == "gpt" + assert config.stop == ["END"] + assert "stop" not in config.parameters + + def test_automatic_metadata_filter_func(self, retrieval: DatasetRetrieval) -> None: + metadata_field = SimpleNamespace(name="author") + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([Mock()]) + model_config = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, 
"completion_tokens": 1, "total_tokens": 2}) + session_scalars = Mock() + session_scalars.all.return_value = [metadata_field] + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, model_config)), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, "_handle_invoke_result", return_value=('{"metadata_map":[]}', usage)), + patch("core.rag.retrieval.dataset_retrieval.parse_and_check_json_markdown") as mock_parse, + patch.object(retrieval, "_record_usage") as mock_record_usage, + ): + mock_parse.return_value = { + "metadata_map": [ + { + "metadata_field_name": "author", + "metadata_field_value": "Alice", + "comparison_operator": "contains", + }, + { + "metadata_field_name": "ignored", + "metadata_field_value": "value", + "comparison_operator": "contains", + }, + ] + } + result = retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + assert result == [{"metadata_name": "author", "value": "Alice", "condition": "contains"}] + mock_record_usage.assert_called_once_with(usage) + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", side_effect=RuntimeError("boom")), + ): + with pytest.raises(RuntimeError, match="boom"): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: + db_query = Mock() + db_query.where.return_value = db_query + db_query.all.return_value = 
[SimpleNamespace(dataset_id="d1", id="doc-1")] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="disabled", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + assert mapping is None + assert condition is None + + automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), + patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), + ): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="automatic", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=AppMetadataFilteringCondition(logical_operator="or", conditions=[]), + inputs={}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.logical_operator == "or" + + manual_conditions = AppMetadataFilteringCondition( + logical_operator="and", + conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="manual", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=manual_conditions, + inputs={"name": "Alice"}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + 
assert condition.conditions[0].value == "Alice" + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with pytest.raises(ValueError, match="Invalid metadata filtering mode"): + retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="unsupported", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + + def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: + session = Mock() + subquery_query = Mock() + subquery_query.where.return_value = subquery_query + subquery_query.group_by.return_value = subquery_query + subquery_query.having.return_value = subquery_query + subquery_query.subquery.return_value = SimpleNamespace( + c=SimpleNamespace( + dataset_id=column("dataset_id"), available_document_count=column("available_document_count") + ) + ) + + dataset_query = Mock() + dataset_query.outerjoin.return_value = dataset_query + dataset_query.where.return_value = dataset_query + dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.query.side_effect = [subquery_query, dataset_query] + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx): + available = retrieval._get_available_datasets("tenant-1", ["d1", "d2"]) + + assert [dataset.id for dataset in available] == ["d1", "d2"] + + def test_check_knowledge_rate_limit(self, retrieval: DatasetRetrieval) -> None: + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", 
return_value=100.0), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=2, subscription_plan="pro") + mock_redis.zcard.return_value = 1 + retrieval._check_knowledge_rate_limit("tenant-1") + mock_redis.zadd.assert_called_once() + + session = Mock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=1, subscription_plan="pro") + mock_redis.zcard.return_value = 2 + with pytest.raises(exc.RateLimitExceededError): + retrieval._check_knowledge_rate_limit("tenant-1") + session.add.assert_called_once() + + with patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit: + mock_limit.return_value = SimpleNamespace(enabled=False) + retrieval._check_knowledge_rate_limit("tenant-1") + + +def _doc( + provider: str = "dify", + content: str = "content", + score: float = 0.9, + dataset_id: str = "dataset-1", + document_id: str = "document-1", + doc_id: str = "node-1", + extra: dict | None = None, +) -> Document: + metadata = { + "score": score, + "dataset_id": dataset_id, + "document_id": document_id, + "doc_id": doc_id, + } + if extra: + metadata.update(extra) + return Document(page_content=content, metadata=metadata, provider=provider) + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def 
join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class _JoinDrivenThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._started = False + self._alive = False + + def start(self) -> None: + self._started = True + self._alive = True + + def join(self, timeout=None) -> None: + if self._started and self._alive and self._target: + self._target(**self._kwargs) + self._alive = False + + def is_alive(self) -> bool: + return self._alive + + +@contextmanager +def _timer(): + yield {"cost": 1} + + +class TestKnowledgeRetrievalCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_returns_empty_when_query_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query=None, + retrieval_mode="multiple", + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + assert retrieval.knowledge_retrieval(request) == [] + + def test_raises_when_metadata_model_config_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query="query", + retrieval_mode="multiple", + metadata_filtering_mode="automatic", + metadata_model_config=None, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + with pytest.raises(ValueError, match="metadata_model_config is required"): + retrieval.knowledge_retrieval(request) + + @pytest.mark.parametrize( + ("status", "error_cls"), + [ + (ModelStatus.NO_CONFIGURE, 
"ModelCredentialsNotInitializedError"), + (ModelStatus.NO_PERMISSION, "ModelNotSupportedError"), + (ModelStatus.QUOTA_EXCEEDED, "ModelQuotaExceededError"), + ], + ) + def test_single_mode_raises_for_model_status( + self, + retrieval: DatasetRetrieval, + status: ModelStatus, + error_cls: str, + ) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["dataset-1"], + query="python", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + ) + provider_model_bundle = Mock() + provider_model_bundle.configuration.get_provider_model.return_value = SimpleNamespace(status=status) + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = Mock() + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + credentials={}, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + ): + mock_model_manager.return_value.get_model_instance.return_value = model_instance + with pytest.raises(Exception) as exc_info: + retrieval.knowledge_retrieval(request) + assert error_cls in type(exc_info.value).__name__ + + +class TestRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def _build_model_config(self, features: list[ModelFeature] | None = None): + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = SimpleNamespace(features=features or []) + provider_bundle = SimpleNamespace(model_type_instance=model_type_instance) + return ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt-4", + model_schema=Mock(), + mode="chat", + provider_model_bundle=provider_bundle, + 
credentials={}, + parameters={}, + stop=[], + ) + + def test_returns_none_when_dataset_ids_empty(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=[], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=self._build_model_config(), + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_returns_none_when_model_schema_missing(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + model_config = self._build_model_config() + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = Mock() + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config() + external_doc = _doc( + provider="external", + content="external content", + dataset_id="ext-ds", + 
document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External", "dataset_name": "External DS"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[external_doc]), + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert context == "external content" + assert files == [] + + def test_multiple_strategy_with_vision_and_source_details(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=4, + score_threshold=0.1, + rerank_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + reranking_enabled=True, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config(features=[ModelFeature.TOOL_CALL]) + external_doc = _doc( + provider="external", + content="external body", + score=0.8, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External Title", "dataset_name": "External DS"}, + ) + dify_doc = _doc( + provider="dify", + content="dify body", + score=0.9, + dataset_id="d1", + document_id="doc-1", + doc_id="node-1", + ) + record = SimpleNamespace( + segment=SimpleNamespace( + id="segment-1", + dataset_id="d1", + document_id="doc-1", + tenant_id="tenant-1", + 
hit_count=3, + word_count=11, + position=1, + index_node_hash="hash-1", + content="segment content", + answer="segment answer", + get_sign_content=lambda: "segment content", + ), + score=0.9, + summary="short summary", + files=None, + ) + dataset_item = SimpleNamespace(id="d1", name="Dataset One") + document_item = SimpleNamespace( + id="doc-1", + name="Document One", + data_source_type="upload_file", + doc_metadata={"lang": "en"}, + ) + upload_file = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="https://example.com/img.png", + size=123, + key="k1", + ) + execute_attachments = SimpleNamespace(all=lambda: [(SimpleNamespace(), upload_file)]) + execute_docs = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [document_item])) + execute_datasets = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [dataset_item])) + hit_callback = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.format_retrieval_documents", + return_value=[record], + ), + patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), + patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.DEBUGGER, + show_retrieve_source=True, + 
hit_callback=hit_callback, + message_id="m1", + vision_enabled=True, + ) + + assert "short summary" in (context or "") + assert "question:segment content answer:segment answer" in (context or "") + assert len(files or []) == 1 + hit_callback.return_retriever_resource_info.assert_called_once() + + +class TestSingleAndMultipleRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="External DS", + description=None, + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end") as mock_end, + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + mock_external.return_value = [ + {"content": "ext result", "metadata": {"k": "v"}, "score": 0.9, "title": "Ext Doc"} + ] + result = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + message_id="m1", + ) + + assert len(result) == 1 + assert result[0].provider == "external" + mock_end.assert_called_once() + assert 
retrieval.llm_usage.total_tokens == 2 + + def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="dataset desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + "top_k": 3, + "score_threshold_enabled": True, + "score_threshold": 0.2, + }, + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}) + result_doc = _doc(provider="dify", score=0.7, dataset_id="ds-1", document_id="doc-1", doc_id="node-1") + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.FunctionCallMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[result_doc] + ) as mock_retrieve, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end"), + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.ROUTER, + metadata_filter_document_ids={"ds-1": ["doc-1"]}, + metadata_condition=SimpleNamespace(), + ) + + assert results == [result_doc] + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc-1"] + assert retrieval.llm_usage.total_tokens == 1 + + def 
test_single_retrieve_returns_empty_when_no_dataset_selected(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls: + mock_router_cls.return_value.invoke.return_value = (None, LLMUsage.empty_usage()) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[ + SimpleNamespace(id="ds-1", name="DS", description=None), + ], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + ) + assert results == [] + + def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={"top_k": 2, "search_method": "semantic_search", "reranking_enable": False}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", LLMUsage.empty_usage()) + no_filter = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids=None, + metadata_condition=SimpleNamespace(), + ) + missing_doc_ids = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids={"other-ds": ["x"]}, + 
metadata_condition=None, + ) + assert no_filter == [] + assert missing_doc_ids == [] + + def test_multiple_retrieve_validation_paths(self, retrieval: DatasetRetrieval) -> None: + assert ( + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=[], + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + == [] + ) + + mixed = [ + SimpleNamespace(id="d1", indexing_technique="high_quality"), + SimpleNamespace(id="d2", indexing_technique="economy"), + ] + with pytest.raises(ValueError, match="different indexing technique"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=mixed, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="weighted_score", + reranking_enable=False, + ) + + high_quality_mismatch = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-b", + embedding_model_provider="provider-b", + ), + ] + with pytest.raises(ValueError, match="different embedding model"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=high_quality_mismatch, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + ) + + def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + ] + doc_a = 
_doc(provider="dify", score=0.8, dataset_id="d1", document_id="doc-1", doc_id="dup") + doc_b = _doc(provider="dify", score=0.7, dataset_id="d2", document_id="doc-2", doc_id="dup") + doc_external = _doc( + provider="external", + score=0.9, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"dataset_name": "Ext", "title": "Ext"}, + ) + app = Flask(__name__) + weights = {"vector_setting": {}} + + def fake_multiple_thread(**kwargs): + if kwargs["query"]: + kwargs["all_documents"].extend([doc_a, doc_b]) + if kwargs["attachment_id"]: + kwargs["all_documents"].append(doc_external) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=fake_multiple_thread), + patch.object(retrieval, "_on_query") as mock_on_query, + patch.object(retrieval, "_on_retrieval_end") as mock_end, + ): + result = retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + weights=weights, + attachment_ids=["att-1"], + message_id="m1", + ) + + assert len(result) == 2 + assert any(doc.provider == "external" for doc in result) + assert weights["vector_setting"]["embedding_provider_name"] == "provider-a" + assert weights["vector_setting"]["embedding_model_name"] == "model-a" + mock_on_query.assert_called_once() + mock_end.assert_called_once() + + def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ) + ] + app = Flask(__name__) + + def failing_thread(**kwargs): + 
kwargs["thread_exceptions"].append(RuntimeError("thread boom")) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=failing_thread), + ): + with pytest.raises(RuntimeError, match="thread boom"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + + +class TestInternalHooksCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_on_retrieval_end_without_dify_documents(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + with patch.object(retrieval, "_send_trace_task") as mock_trace: + retrieval._on_retrieval_end( + flask_app=app, + documents=[_doc(provider="external")], + message_id="m1", + timer={"cost": 1}, + ) + mock_trace.assert_called_once() + + def test_on_retrieval_end_dify_without_document_ids(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + doc = Document(page_content="x", metadata={"doc_id": "n1"}, provider="dify") + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=[doc], message_id="m1", timer={"cost": 1}) + mock_trace.assert_called_once() + + def test_on_retrieval_end_updates_segments_for_text_and_image(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + docs = [ + _doc(provider="dify", document_id="doc-a", doc_id="idx-a", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-b", doc_id="att-b", extra={"doc_type": DocType.IMAGE}), + _doc(provider="dify", document_id="doc-c", 
doc_id="idx-c", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-d", doc_id="att-d", extra={"doc_type": DocType.IMAGE}), + ] + dataset_docs = [ + SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-c", doc_form="qa_model"), + SimpleNamespace(id="doc-d", doc_form="qa_model"), + ] + child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] + segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] + bindings = [SimpleNamespace(segment_id="seg-b"), SimpleNamespace(segment_id="seg-d")] + + def _scalars(items): + result = Mock() + result.all.return_value = items + return result + + session = Mock() + session.scalars.side_effect = [ + _scalars(dataset_docs), + _scalars(child_chunks), + _scalars(segments), + _scalars(bindings), + ] + query = Mock() + query.where.return_value = query + session.query.return_value = query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) + + query.update.assert_called_once() + session.commit.assert_called_once() + mock_trace.assert_called_once() + + def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: + flask_app = SimpleNamespace(app_context=lambda: nullcontext()) + all_documents: list[Document] = [] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=None): + assert ( + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="d1", + query="python", + top_k=1, + 
all_documents=all_documents, + ) + == [] + ) + + external_dataset = SimpleNamespace( + id="ext-ds", + name="External", + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=external_dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + ): + mock_external.return_value = [{"content": "e", "metadata": {}, "score": 0.8, "title": "Ext"}] + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="ext-ds", + query="python", + top_k=1, + all_documents=all_documents, + ) + + economy_dataset = SimpleNamespace( + id="eco-ds", + provider="dify", + retrieval_model={"top_k": 1}, + indexing_technique="economy", + ) + high_dataset = SimpleNamespace( + id="hq-ds", + provider="dify", + retrieval_model={ + "search_method": "semantic_search", + "top_k": 4, + "score_threshold": 0.3, + "score_threshold_enabled": True, + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "x", "reranking_model_name": "y"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + }, + indexing_technique="high_quality", + ) + with ( + patch( + "core.rag.retrieval.dataset_retrieval.db.session.scalar", side_effect=[economy_dataset, high_dataset] + ), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[_doc(provider="dify")] + ) as mock_retrieve, + ): + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="eco-ds", + query="python", + top_k=2, + all_documents=all_documents, + ) + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="hq-ds", + query="python", + top_k=2, + all_documents=all_documents, + attachment_ids=["att-1"], + ) + assert mock_retrieve.call_count == 2 + assert len(all_documents) >= 3 + + 
def test_to_dataset_retriever_tool_paths(self, retrieval: DatasetRetrieval) -> None: + dataset_skip_zero = SimpleNamespace(id="d1", provider="dify", available_document_count=0) + dataset_ok_single = SimpleNamespace( + id="d2", + provider="dify", + available_document_count=2, + retrieval_model={"top_k": 2, "score_threshold_enabled": True, "score_threshold": 0.1}, + ) + single_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + with ( + patch( + "core.rag.retrieval.dataset_retrieval.db.session.scalar", + side_effect=[None, dataset_skip_zero, dataset_ok_single], + ), + patch( + "core.tools.utils.dataset_retriever.dataset_retriever_tool.DatasetRetrieverTool.from_dataset", + return_value="single-tool", + ) as mock_single_tool, + ): + single_tools = retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["missing", "d1", "d2"], + retrieve_config=single_config, + return_resource=True, + invoke_from=InvokeFrom.WEB_APP, + hit_callback=Mock(), + user_id="user-1", + inputs={"k": "v"}, + ) + + assert single_tools == ["single-tool"] + mock_single_tool.assert_called_once() + + multiple_config_missing = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + metadata_filtering_mode="disabled", + reranking_model=None, + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single): + with pytest.raises(ValueError, match="Reranking model is required"): + retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["d2"], + retrieve_config=multiple_config_missing, + return_resource=True, + invoke_from=InvokeFrom.WEB_APP, + hit_callback=Mock(), + user_id="user-1", + inputs={}, + ) + + multiple_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + metadata_filtering_mode="disabled", + top_k=3, + 
score_threshold=0.2, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single), + patch( + "core.tools.utils.dataset_retriever.dataset_multi_retriever_tool.DatasetMultiRetrieverTool.from_dataset", + return_value="multi-tool", + ) as mock_multi_tool, + ): + multi_tools = retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["d2"], + retrieve_config=multiple_config, + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="user-1", + inputs={}, + ) + assert multi_tools == ["multi-tool"] + mock_multi_tool.assert_called_once() + + def test_additional_small_branches(self, retrieval: DatasetRetrieval) -> None: + keyword_handler = Mock() + keyword_handler.extract_keywords.side_effect = [[], []] + doc = Document(page_content="doc", metadata={"doc_id": "1"}, provider="dify") + with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler): + ranked = retrieval.calculate_keyword_score("query", [doc], top_k=1) + assert len(ranked) == 1 + assert ranked[0].metadata.get("score") == 0.0 + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + with pytest.raises(ValueError): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="user-1", + metadata_model_config=None, # type: ignore[arg-type] + ) + + session_scalars = Mock() + session_scalars.all.return_value = [SimpleNamespace(name="author")] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(Mock(), Mock())), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, 
"_record_usage"), + ): + model_instance = Mock() + model_instance.invoke_llm.side_effect = RuntimeError("nope") + with patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, Mock())): + assert ( + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="user-1", + metadata_model_config=WorkflowModelConfig(provider="openai", name="gpt", mode="chat"), + ) + is None + ) + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelMode", return_value=object()), + patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform"), + ): + with pytest.raises(ValueError, match="not support"): + retrieval._get_prompt_template( + model_config=ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ), + mode="chat", + metadata_fields=[], + query="q", + ) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py deleted file mode 100644 index 07d6e51e4b..0000000000 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py +++ /dev/null @@ -1,873 +0,0 @@ -""" -Unit tests for DatasetRetrieval.process_metadata_filter_func. - -This module provides comprehensive test coverage for the process_metadata_filter_func -method in the DatasetRetrieval class, which is responsible for building SQLAlchemy -filter expressions based on metadata filtering conditions. - -Conditions Tested: -================== -1. **String Conditions**: contains, not contains, start with, end with -2. **Equality Conditions**: is / =, is not / ≠ -3. **Null Conditions**: empty, not empty -4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >= -5. **List Conditions**: in -6. 
**Edge Cases**: None values, different data types (str, int, float) - -Test Architecture: -================== -- Direct instantiation of DatasetRetrieval -- Mocking of DatasetDocument model attributes -- Verification of SQLAlchemy filter expressions -- Follows Arrange-Act-Assert (AAA) pattern - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\ -TestProcessMetadataFilterFunc::test_contains_condition -v -""" - -from unittest.mock import MagicMock - -import pytest - -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval - - -class TestProcessMetadataFilterFunc: - """ - Comprehensive test suite for process_metadata_filter_func method. - - This test class validates all metadata filtering conditions supported by - the DatasetRetrieval class, including string operations, numeric comparisons, - null checks, and list operations. - - Method Signature: - ================== - def process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list - ) -> list: - - The method builds SQLAlchemy filter expressions by: - 1. Validating value is not None (except for empty/not empty conditions) - 2. Using DatasetDocument.doc_metadata JSON field operations - 3. Adding appropriate SQLAlchemy expressions to the filters list - 4. Returning the updated filters list - - Mocking Strategy: - ================== - - Mock DatasetDocument.doc_metadata to avoid database dependencies - - Verify filter expressions are created correctly - - Test with various data types (str, int, float, list) - """ - - @pytest.fixture - def retrieval(self): - """ - Create a DatasetRetrieval instance for testing. 
- - Returns: - DatasetRetrieval: Instance to test process_metadata_filter_func - """ - return DatasetRetrieval() - - @pytest.fixture - def mock_doc_metadata(self): - """ - Mock the DatasetDocument.doc_metadata JSON field. - - The method uses DatasetDocument.doc_metadata[metadata_name] to access - JSON fields. We mock this to avoid database dependencies. - - Returns: - Mock: Mocked doc_metadata attribute - """ - mock_metadata_field = MagicMock() - - # Create mock for string access - mock_string_access = MagicMock() - mock_string_access.like = MagicMock() - mock_string_access.notlike = MagicMock() - mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) - mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) - mock_string_access.in_ = MagicMock(return_value=MagicMock()) - - # Create mock for float access (for numeric comparisons) - mock_float_access = MagicMock() - mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) - mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) - mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) - mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) - mock_float_access.__le__ = MagicMock(return_value=MagicMock()) - mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) - - # Create mock for null checks - mock_null_access = MagicMock() - mock_null_access.is_ = MagicMock(return_value=MagicMock()) - mock_null_access.isnot = MagicMock(return_value=MagicMock()) - - # Setup __getitem__ to return appropriate mock based on usage - def getitem_side_effect(name): - if name in ["author", "title", "category"]: - return mock_string_access - elif name in ["year", "price", "rating"]: - return mock_float_access - else: - return mock_string_access - - mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) - mock_metadata_field.as_string.return_value = mock_string_access - mock_metadata_field.as_float.return_value = mock_float_access - 
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ - mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot - - return mock_metadata_field - - # ==================== String Condition Tests ==================== - - def test_contains_condition_string_value(self, retrieval): - """ - Test 'contains' condition with string value. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses %value% syntax - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "John" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_contains_condition(self, retrieval): - """ - Test 'not contains' condition. - - Verifies: - - Filters list is populated with NOT LIKE expression - - Pattern matching uses %value% syntax with negation - """ - filters = [] - sequence = 0 - condition = "not contains" - metadata_name = "title" - value = "banned" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_start_with_condition(self, retrieval): - """ - Test 'start with' condition. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses value% syntax - """ - filters = [] - sequence = 0 - condition = "start with" - metadata_name = "category" - value = "tech" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_end_with_condition(self, retrieval): - """ - Test 'end with' condition. 
- - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses %value syntax - """ - filters = [] - sequence = 0 - condition = "end with" - metadata_name = "filename" - value = ".pdf" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Equality Condition Tests ==================== - - def test_is_condition_with_string_value(self, retrieval): - """ - Test 'is' (=) condition with string value. - - Verifies: - - Filters list is populated with equality expression - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "author" - value = "Jane Doe" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_equals_condition_with_string_value(self, retrieval): - """ - Test '=' condition with string value. - - Verifies: - - Same behavior as 'is' condition - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "=" - metadata_name = "category" - value = "technology" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_condition_with_int_value(self, retrieval): - """ - Test 'is' condition with integer value. - - Verifies: - - Numeric comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "year" - value = 2023 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_condition_with_float_value(self, retrieval): - """ - Test 'is' condition with float value. 
- - Verifies: - - Numeric comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "price" - value = 19.99 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_not_condition_with_string_value(self, retrieval): - """ - Test 'is not' (≠) condition with string value. - - Verifies: - - Filters list is populated with inequality expression - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "is not" - metadata_name = "author" - value = "Unknown" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_equals_condition(self, retrieval): - """ - Test '≠' condition with string value. - - Verifies: - - Same behavior as 'is not' condition - - Inequality expression is used - """ - filters = [] - sequence = 0 - condition = "≠" - metadata_name = "category" - value = "archived" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_not_condition_with_numeric_value(self, retrieval): - """ - Test 'is not' condition with numeric value. - - Verifies: - - Numeric inequality comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is not" - metadata_name = "year" - value = 2000 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Null Condition Tests ==================== - - def test_empty_condition(self, retrieval): - """ - Test 'empty' condition (null check). 
- - Verifies: - - Filters list is populated with IS NULL expression - - Value can be None for this condition - """ - filters = [] - sequence = 0 - condition = "empty" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_empty_condition(self, retrieval): - """ - Test 'not empty' condition (not null check). - - Verifies: - - Filters list is populated with IS NOT NULL expression - - Value can be None for this condition - """ - filters = [] - sequence = 0 - condition = "not empty" - metadata_name = "description" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Numeric Comparison Tests ==================== - - def test_before_condition(self, retrieval): - """ - Test 'before' (<) condition. - - Verifies: - - Filters list is populated with less than expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "before" - metadata_name = "year" - value = 2020 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_condition(self, retrieval): - """ - Test '<' condition. - - Verifies: - - Same behavior as 'before' condition - - Less than expression is used - """ - filters = [] - sequence = 0 - condition = "<" - metadata_name = "price" - value = 100.0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_after_condition(self, retrieval): - """ - Test 'after' (>) condition. 
- - Verifies: - - Filters list is populated with greater than expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "after" - metadata_name = "year" - value = 2020 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_condition(self, retrieval): - """ - Test '>' condition. - - Verifies: - - Same behavior as 'after' condition - - Greater than expression is used - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "rating" - value = 4.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_or_equal_condition_unicode(self, retrieval): - """ - Test '≤' condition. - - Verifies: - - Filters list is populated with less than or equal expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "≤" - metadata_name = "price" - value = 50.0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_or_equal_condition_ascii(self, retrieval): - """ - Test '<=' condition. - - Verifies: - - Same behavior as '≤' condition - - Less than or equal expression is used - """ - filters = [] - sequence = 0 - condition = "<=" - metadata_name = "year" - value = 2023 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_or_equal_condition_unicode(self, retrieval): - """ - Test '≥' condition. 
- - Verifies: - - Filters list is populated with greater than or equal expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "≥" - metadata_name = "rating" - value = 3.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_or_equal_condition_ascii(self, retrieval): - """ - Test '>=' condition. - - Verifies: - - Same behavior as '≥' condition - - Greater than or equal expression is used - """ - filters = [] - sequence = 0 - condition = ">=" - metadata_name = "year" - value = 2000 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== List/In Condition Tests ==================== - - def test_in_condition_with_comma_separated_string(self, retrieval): - """ - Test 'in' condition with comma-separated string value. - - Verifies: - - String is split into list - - Whitespace is trimmed from each value - - IN expression is created - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "tech, science, AI " - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_list_value(self, retrieval): - """ - Test 'in' condition with list value. 
- - Verifies: - - List is processed correctly - - None values are filtered out - - IN expression is created with valid values - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "tags" - value = ["python", "javascript", None, "golang"] - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_tuple_value(self, retrieval): - """ - Test 'in' condition with tuple value. - - Verifies: - - Tuple is processed like a list - - IN expression is created - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = ("tech", "science", "ai") - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_empty_string(self, retrieval): - """ - Test 'in' condition with empty string value. - - Verifies: - - Empty string results in literal(False) filter - - No valid values to match - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - # Verify it's a literal(False) expression - # This is a bit tricky to test without access to the actual expression - - def test_in_condition_with_only_whitespace(self, retrieval): - """ - Test 'in' condition with whitespace-only string value. 
- - Verifies: - - Whitespace-only string results in literal(False) filter - - All values are stripped and filtered out - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = " , , " - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_single_string(self, retrieval): - """ - Test 'in' condition with single non-comma string. - - Verifies: - - Single string is treated as single-item list - - IN expression is created with one value - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "technology" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Edge Case Tests ==================== - - def test_none_value_with_non_empty_condition(self, retrieval): - """ - Test None value with conditions that require value. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values (except empty/not empty) - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 # No filter added - - def test_none_value_with_equals_condition(self, retrieval): - """ - Test None value with 'is' (=) condition. 
- - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_none_value_with_numeric_condition(self, retrieval): - """ - Test None value with numeric comparison condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "year" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_existing_filters_preserved(self, retrieval): - """ - Test that existing filters are preserved. - - Verifies: - - Existing filters in the list are not removed - - New filters are appended to the list - """ - existing_filter = MagicMock() - filters = [existing_filter] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "test" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 2 - assert filters[0] == existing_filter - - def test_multiple_filters_accumulated(self, retrieval): - """ - Test multiple calls to accumulate filters. 
- - Verifies: - - Each call adds a new filter to the list - - All filters are preserved across calls - """ - filters = [] - - # First filter - retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) - assert len(filters) == 1 - - # Second filter - retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) - assert len(filters) == 2 - - # Third filter - retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) - assert len(filters) == 3 - - def test_unknown_condition(self, retrieval): - """ - Test unknown/unsupported condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for unknown conditions - """ - filters = [] - sequence = 0 - condition = "unknown_condition" - metadata_name = "author" - value = "test" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_empty_string_value_with_contains(self, retrieval): - """ - Test empty string value with 'contains' condition. - - Verifies: - - Filter is added even with empty string - - LIKE expression is created - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_special_characters_in_value(self, retrieval): - """ - Test special characters in value string. 
- - Verifies: - - Special characters are handled in value - - LIKE expression is created correctly - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "title" - value = "C++ & Python's features" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_zero_value_with_numeric_condition(self, retrieval): - """ - Test zero value with numeric comparison condition. - - Verifies: - - Zero is treated as valid value - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "price" - value = 0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_negative_value_with_numeric_condition(self, retrieval): - """ - Test negative value with numeric comparison condition. - - Verifies: - - Negative numbers are handled correctly - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = "<" - metadata_name = "temperature" - value = -10.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_float_value_with_integer_comparison(self, retrieval): - """ - Test float value with numeric comparison condition. 
- - Verifies: - - Float values work correctly - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = ">=" - metadata_name = "rating" - value = 4.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py deleted file mode 100644 index 5f461d53ae..0000000000 --- a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py +++ /dev/null @@ -1,113 +0,0 @@ -import threading -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest -from flask import Flask, current_app - -from core.rag.models.document import Document -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from models.dataset import Dataset - - -class TestRetrievalService: - @pytest.fixture - def mock_dataset(self) -> Dataset: - dataset = Mock(spec=Dataset) - dataset.id = str(uuid4()) - dataset.tenant_id = str(uuid4()) - dataset.name = "test_dataset" - dataset.indexing_technique = "high_quality" - dataset.provider = "dify" - return dataset - - def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): - """ - Repro test for current bug: - reranking runs after `with flask_app.app_context():` exits. - `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, - so we must assert from that list (not from an outer try/except). 
- """ - dataset_retrieval = DatasetRetrieval() - flask_app = Flask(__name__) - tenant_id = str(uuid4()) - - # second dataset to ensure dataset_count > 1 reranking branch - secondary_dataset = Mock(spec=Dataset) - secondary_dataset.id = str(uuid4()) - secondary_dataset.provider = "dify" - secondary_dataset.indexing_technique = "high_quality" - - # retriever returns 1 doc into internal list (all_documents_item) - document = Document( - page_content="Context aware doc", - metadata={ - "doc_id": "doc1", - "score": 0.95, - "document_id": str(uuid4()), - "dataset_id": mock_dataset.id, - }, - provider="dify", - ) - - def fake_retriever( - flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids - ): - all_documents.append(document) - - called = {"init": 0, "invoke": 0} - - class ContextRequiredPostProcessor: - def __init__(self, *args, **kwargs): - called["init"] += 1 - # will raise RuntimeError if no Flask app context exists - _ = current_app.name - - def invoke(self, *args, **kwargs): - called["invoke"] += 1 - _ = current_app.name - return kwargs.get("documents") or args[1] - - # output list from _multiple_retrieve_thread - all_documents: list[Document] = [] - - # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here - thread_exceptions: list[Exception] = [] - - def target(): - with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): - with patch( - "core.rag.retrieval.dataset_retrieval.DataPostProcessor", - ContextRequiredPostProcessor, - ): - dataset_retrieval._multiple_retrieve_thread( - flask_app=flask_app, - available_datasets=[mock_dataset, secondary_dataset], - metadata_condition=None, - metadata_filter_document_ids=None, - all_documents=all_documents, - tenant_id=tenant_id, - reranking_enable=True, - reranking_mode="reranking_model", - reranking_model={ - "reranking_provider_name": "cohere", - "reranking_model_name": "rerank-v2", - }, - weights=None, - top_k=3, - 
score_threshold=0.0, - query="test query", - attachment_id=None, - dataset_count=2, # force reranking branch - thread_exceptions=thread_exceptions, # ✅ key - ) - - t = threading.Thread(target=target) - t.start() - t.join() - - # Ensure reranking branch was actually executed - assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." - - # Current buggy code should record an exception (not raise it) - assert not thread_exceptions, thread_exceptions diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py new file mode 100644 index 0000000000..cfa9094e12 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -0,0 +1,100 @@ +from unittest.mock import Mock + +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from dify_graph.model_runtime.entities.llm_entities import LLMUsage + + +class TestFunctionCallMultiDatasetRouter: + def test_invoke_returns_none_when_no_tools(self) -> None: + router = FunctionCallMultiDatasetRouter() + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[], + model_config=Mock(), + model_instance=Mock(), + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_single_tool_directly(self) -> None: + router = FunctionCallMultiDatasetRouter() + tool = Mock() + tool.name = "dataset-1" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool], + model_config=Mock(), + model_instance=Mock(), + ) + + assert dataset_id == "dataset-1" + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_tool_from_model_response(self) -> None: + router = FunctionCallMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + usage = 
LLMUsage.empty_usage() + response = Mock() + response.usage = usage + response.message.tool_calls = [Mock(function=Mock())] + response.message.tool_calls[0].function.name = "dataset-2" + model_instance = Mock() + model_instance.invoke_llm.return_value = response + + dataset_id, returned_usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id == "dataset-2" + assert returned_usage == usage + model_instance.invoke_llm.assert_called_once() + + def test_invoke_returns_none_when_no_tool_calls(self) -> None: + router = FunctionCallMultiDatasetRouter() + response = Mock() + response.usage = LLMUsage.empty_usage() + response.message.tool_calls = [] + model_instance = Mock() + model_instance.invoke_llm.return_value = response + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id is None + assert usage == response.usage + + def test_invoke_returns_empty_usage_when_model_raises(self) -> None: + router = FunctionCallMultiDatasetRouter() + model_instance = Mock() + model_instance.invoke_llm.side_effect = RuntimeError("boom") + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py new file mode 100644 index 0000000000..e429563739 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -0,0 +1,252 @@ +from types import 
SimpleNamespace +from unittest.mock import Mock, patch + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole + + +class TestReactMultiDatasetRouter: + def test_invoke_returns_none_when_no_tools(self) -> None: + router = ReactMultiDatasetRouter() + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_single_tool_directly(self) -> None: + router = ReactMultiDatasetRouter() + tool = Mock() + tool.name = "dataset-1" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id == "dataset-1" + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_tool_from_react_invoke(self) -> None: + router = ReactMultiDatasetRouter() + usage = LLMUsage.empty_usage() + tool_1 = Mock(name="dataset-1") + tool_1.name = "dataset-1" + tool_2 = Mock(name="dataset-2") + tool_2.name = "dataset-2" + + with patch.object(router, "_react_invoke", return_value=("dataset-2", usage)) as mock_react: + dataset_id, returned_usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + mock_react.assert_called_once() + assert dataset_id == "dataset-2" + assert returned_usage == usage + + def test_invoke_handles_react_invoke_errors(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + with 
patch.object(router, "_react_invoke", side_effect=RuntimeError("boom")): + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_react_invoke_returns_action_tool(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "chat" + model_config.parameters = {"temperature": 0.1} + usage = LLMUsage.empty_usage() + tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] + tools[0].name = "dataset-1" + tools[0].description = "desc" + tools[1].name = "dataset-2" + tools[1].description = "desc" + + with ( + patch.object(router, "create_chat_prompt", return_value=[Mock()]) as mock_chat_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object(router, "_invoke_llm", return_value=('{"action":"dataset-2","action_input":{}}', usage)), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactAction("dataset-2", {}, "log") + + dataset_id, returned_usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=tools, + user_id="u1", + tenant_id="t1", + ) + + mock_chat_prompt.assert_called_once() + assert dataset_id == "dataset-2" + assert returned_usage == usage + + def test_react_invoke_returns_none_for_finish(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "completion" + model_config.parameters = {"temperature": 0.1} + usage = LLMUsage.empty_usage() + tool = Mock() + tool.name = "dataset-1" + tool.description = "desc" + + with ( + patch.object(router, "create_completion_prompt", 
return_value=Mock()) as mock_completion_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object( + router, "_invoke_llm", return_value=('{"action":"Final Answer","action_input":"done"}', usage) + ), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log") + + dataset_id, returned_usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=[tool], + user_id="u1", + tenant_id="t1", + ) + + mock_completion_prompt.assert_called_once() + assert dataset_id is None + assert returned_usage == usage + + def test_invoke_llm_and_handle_result(self) -> None: + router = ReactMultiDatasetRouter() + usage = LLMUsage.empty_usage() + delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=usage) + chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta) + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([chunk]) + + with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + text, returned_usage = router._invoke_llm( + completion_param={"temperature": 0.1}, + model_instance=model_instance, + prompt_messages=[Mock()], + stop=["Observation:"], + user_id="u1", + tenant_id="t1", + ) + + assert text == "part" + assert returned_usage == usage + mock_deduct.assert_called_once() + + def test_handle_invoke_result_with_empty_usage(self) -> None: + router = ReactMultiDatasetRouter() + delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=None) + chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta) + + text, usage = router._handle_invoke_result(iter([chunk])) + + assert text == "part" + assert usage == 
LLMUsage.empty_usage() + + def test_create_chat_prompt(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_1.description = "d1" + tool_2 = Mock() + tool_2.name = "dataset-2" + tool_2.description = "d2" + + chat_prompt = router.create_chat_prompt(query="python", tools=[tool_1, tool_2]) + assert len(chat_prompt) == 2 + assert chat_prompt[0].role == PromptMessageRole.SYSTEM + assert chat_prompt[1].role == PromptMessageRole.USER + assert "dataset-1" in chat_prompt[0].text + assert "dataset-2" in chat_prompt[0].text + + def test_create_completion_prompt(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_1.description = "d1" + tool_2 = Mock() + tool_2.name = "dataset-2" + tool_2.description = "d2" + + completion_prompt = router.create_completion_prompt(tools=[tool_1, tool_2]) + assert "dataset-1: d1" in completion_prompt.text + assert "dataset-2: d2" in completion_prompt.text + + def test_react_invoke_uses_completion_branch_for_non_chat_mode(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "unknown-mode" + model_config.parameters = {} + tool = Mock() + tool.name = "dataset-1" + tool.description = "desc" + + with ( + patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object( + router, + "_invoke_llm", + return_value=('{"action":"Final Answer","action_input":"done"}', LLMUsage.empty_usage()), + ), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log") + dataset_id, usage = router._react_invoke( + query="python", + 
model_config=model_config, + model_instance=Mock(), + tools=[tool], + user_id="u1", + tenant_id="t1", + ) + + mock_completion_prompt.assert_called_once() + assert dataset_id is None + assert usage == LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py new file mode 100644 index 0000000000..c8fa0ea62f --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py @@ -0,0 +1,69 @@ +import pytest + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser + + +class TestStructuredChatOutputParser: + def test_parse_action_without_action_input(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"action":"some_action"}\n```' + result = parser.parse(text) + + assert isinstance(result, ReactAction) + assert result.tool == "some_action" + assert result.tool_input == {} + + def test_parse_json_without_action_key(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"not_action":"search"}\n```' + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.parse(text) + + def test_parse_returns_action_for_tool_call(self) -> None: + parser = StructuredChatOutputParser() + text = ( + 'Thought: call tool\nAction:\n```json\n{"action":"search_dataset","action_input":{"query":"python"}}\n```' + ) + + result = parser.parse(text) + + assert isinstance(result, ReactAction) + assert result.tool == "search_dataset" + assert result.tool_input == {"query": "python"} + assert result.log == text + + def test_parse_returns_finish_for_final_answer(self) -> None: + parser = StructuredChatOutputParser() + text = 'Thought: done\nAction:\n```json\n{"action":"Final Answer","action_input":"final text"}\n```' + + result = 
parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": "final text"} + assert result.log == text + + def test_parse_returns_finish_for_json_array_payload(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n[{"action":"search","action_input":"hello"}]\n```' + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": text} + assert result.log == text + + def test_parse_returns_finish_for_plain_text(self) -> None: + parser = StructuredChatOutputParser() + text = "No structured action block" + + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": text} + + def test_parse_raises_value_error_for_invalid_json(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"action":"search","action_input": }\n```' + + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.parse(text) diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 943a9e5712..976de10d89 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -125,7 +125,11 @@ Run with coverage: - Tests are organized by functionality in classes for better organization """ +import asyncio import string +import sys +import types +from inspect import currentframe from unittest.mock import Mock, patch import pytest @@ -604,6 +608,51 @@ class TestRecursiveCharacterTextSplitter: assert "def hello_world" in combined or "hello_world" in combined +class TestTextSplitterBasePaths: + """Target uncovered base TextSplitter paths.""" + + def test_from_huggingface_tokenizer_success_path(self): + """Cover from_huggingface_tokenizer success branch with mocked transformers.""" + + class _FakePreTrainedTokenizerBase: + pass 
+ + class _FakeTokenizer(_FakePreTrainedTokenizerBase): + def encode(self, text: str): + return [ord(c) for c in text] + + fake_transformers = types.SimpleNamespace(PreTrainedTokenizerBase=_FakePreTrainedTokenizerBase) + with patch.dict(sys.modules, {"transformers": fake_transformers}): + splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer( + tokenizer=_FakeTokenizer(), + chunk_size=5, + chunk_overlap=1, + ) + + chunks = splitter.split_text("abcdef") + assert chunks + + def test_from_huggingface_tokenizer_import_error(self): + """Cover from_huggingface_tokenizer import-error branch.""" + with patch.dict(sys.modules, {"transformers": None}): + with pytest.raises(ValueError, match="Could not import transformers"): + RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer=object(), chunk_size=5) + + def test_atransform_documents_raises_not_implemented(self): + """Cover atransform_documents NotImplemented branch.""" + splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5) + with pytest.raises(NotImplementedError): + asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})])) + + def test_merge_splits_logs_warning_for_oversized_total(self): + """Cover logger.warning path in _merge_splits.""" + splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1) + with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning: + merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1]) + assert merged + mock_warning.assert_called_once() + + # ============================================================================ # Test TokenTextSplitter # ============================================================================ @@ -662,6 +711,44 @@ class TestTokenTextSplitter: except ImportError: pytest.skip("tiktoken not installed") + def test_initialization_and_split_with_mocked_tiktoken_encoding(self): + """Cover TokenTextSplitter __init__ else-path and split_text logic.""" + + 
class _FakeEncoding: + def encode(self, text: str, allowed_special=None, disallowed_special=None): + return [ord(c) for c in text] + + def decode(self, token_ids: list[int]) -> str: + return "".join(chr(i) for i in token_ids) + + fake_tiktoken = types.SimpleNamespace(get_encoding=lambda name: _FakeEncoding()) + with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}): + splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=4, chunk_overlap=1) + result = splitter.split_text("abcdefgh") + + assert result + assert all(isinstance(chunk, str) for chunk in result) + + def test_initialization_with_model_name_uses_encoding_for_model(self): + """Cover TokenTextSplitter model_name init branch.""" + + class _FakeEncoding: + def encode(self, text: str, allowed_special=None, disallowed_special=None): + return [ord(c) for c in text] + + def decode(self, token_ids: list[int]) -> str: + return "".join(chr(i) for i in token_ids) + + fake_encoding = _FakeEncoding() + fake_tiktoken = types.SimpleNamespace( + encoding_for_model=lambda model_name: fake_encoding, + get_encoding=lambda name: _FakeEncoding(), + ) + with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}): + splitter = TokenTextSplitter(model_name="gpt-4", chunk_size=5, chunk_overlap=1) + + assert splitter._tokenizer is fake_encoding + # ============================================================================ # Test EnhanceRecursiveCharacterTextSplitter @@ -731,6 +818,50 @@ class TestEnhanceRecursiveCharacterTextSplitter: assert len(result) > 0 assert all(isinstance(chunk, str) for chunk in result) + def test_from_encoder_internal_token_encoder_paths(self): + """ + Test internal _token_encoder branches by capturing local closure from frame. 
+ + This validates: + - empty texts path + - embedding model path + - GPT2Tokenizer fallback path + - _character_encoder empty-path branch + """ + + class _SpySplitter(EnhanceRecursiveCharacterTextSplitter): + captured_token_encoder = None + captured_character_encoder = None + + def __init__(self, **kwargs): + frame = currentframe() + if frame and frame.f_back: + _SpySplitter.captured_token_encoder = frame.f_back.f_locals.get("_token_encoder") + _SpySplitter.captured_character_encoder = frame.f_back.f_locals.get("_character_encoder") + super().__init__(**kwargs) + + mock_model = Mock() + mock_model.get_text_embedding_num_tokens.return_value = [3, 5] + + _SpySplitter.from_encoder(embedding_model_instance=mock_model, chunk_size=10, chunk_overlap=1) + token_encoder = _SpySplitter.captured_token_encoder + character_encoder = _SpySplitter.captured_character_encoder + + assert token_encoder is not None + assert character_encoder is not None + assert token_encoder([]) == [] + assert token_encoder(["abc", "defgh"]) == [3, 5] + assert character_encoder([]) == [] + + with patch( + "core.rag.splitter.fixed_text_splitter.GPT2Tokenizer.get_num_tokens", + side_effect=lambda text: len(text) + 1, + ): + _SpySplitter.from_encoder(embedding_model_instance=None, chunk_size=10, chunk_overlap=1) + token_encoder_without_model = _SpySplitter.captured_token_encoder + assert token_encoder_without_model is not None + assert token_encoder_without_model(["ab", "cdef"]) == [3, 5] + # ============================================================================ # Test FixedRecursiveCharacterTextSplitter @@ -908,6 +1039,56 @@ class TestFixedRecursiveCharacterTextSplitter: chunks = splitter.split_text(data) assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + def test_recursive_split_keep_separator_and_recursive_fallback(self): + """Cover keep-separator split branch and recursive _split_text fallback.""" + text = "short." 
+ ("x" * 60) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=[".", " ", ""], + chunk_size=10, + chunk_overlap=2, + keep_separator=True, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert any("short." in chunk for chunk in chunks) + assert any(len(chunk) <= 12 for chunk in chunks) + + def test_recursive_split_newline_separator_filtering(self): + """Cover newline-specific empty filtering branch.""" + text = "line1\n\nline2\n\nline3" + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n", ""], + chunk_size=50, + chunk_overlap=5, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert all(chunk != "" for chunk in chunks) + assert "line1" in "".join(chunks) + assert "line2" in "".join(chunks) + assert "line3" in "".join(chunks) + + def test_recursive_split_without_new_separator_appends_long_chunk(self): + """Cover branch where no further separators exist and long split is appended directly.""" + text = "aa\n" + ("b" * 40) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n"], + chunk_size=10, + chunk_overlap=2, + ) + + chunks = splitter.recursive_split_text(text) + + assert "aa" in chunks + assert any(len(chunk) >= 40 for chunk in chunks) + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py new file mode 100644 index 0000000000..c66e50437a --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -0,0 +1,84 @@ +from datetime import datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from 
core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType +from models import Account, WorkflowRun +from models.enums import WorkflowRunTriggeredFrom + + +def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository: + engine = create_engine("sqlite:///:memory:") + real_session_factory = sessionmaker(bind=engine, expire_on_commit=False) + + user = MagicMock(spec=Account) + user.id = str(uuid4()) + user.current_tenant_id = str(uuid4()) + + repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=real_session_factory, + user=user, + app_id="app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = False + repository._session_factory = MagicMock(return_value=session_context) + return repository + + +def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution: + return WorkflowExecution.new( + id_=execution_id, + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0.0", + graph={"nodes": [], "edges": []}, + inputs={"query": "hello"}, + started_at=started_at, + ) + + +def test_save_uses_execution_started_at_when_record_does_not_exist(): + session = MagicMock() + session.get.return_value = None + repository = _build_repository_with_mocked_session(session) + + started_at = datetime(2026, 1, 1, 12, 0, 0) + execution = _build_execution(execution_id=str(uuid4()), started_at=started_at) + + repository.save(execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + +def test_save_preserves_existing_created_at_when_record_already_exists(): + session = MagicMock() + repository = 
_build_repository_with_mocked_session(session) + + execution_id = str(uuid4()) + existing_created_at = datetime(2026, 1, 1, 12, 0, 0) + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repository._tenant_id + existing_run.created_at = existing_created_at + session.get.return_value = existing_run + + execution = _build_execution( + execution_id=execution_id, + started_at=datetime(2026, 1, 1, 12, 30, 0), + ) + + repository.save(execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/trigger/__init__.py b/api/tests/unit_tests/core/trigger/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/conftest.py b/api/tests/unit_tests/core/trigger/conftest.py new file mode 100644 index 0000000000..d9da80a8b7 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/conftest.py @@ -0,0 +1,93 @@ +"""Shared factory helpers for core.trigger test suite.""" + +from __future__ import annotations + +from typing import Any + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.entities.entities import ( + EventEntity, + EventIdentity, + EventParameter, + OAuthSchema, + Subscription, + SubscriptionConstructor, + TriggerProviderEntity, + TriggerProviderIdentity, +) +from core.trigger.provider import PluginTriggerProviderController +from models.provider_ids import TriggerProviderID + +# Valid format for TriggerProviderID: org/plugin/provider +VALID_PROVIDER_ID = "testorg/testplugin/testprovider" + + +def i18n(text: str = "test") -> I18nObject: + return I18nObject(en_US=text, zh_Hans=text) + + +def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity: + return EventEntity( + identity=EventIdentity(author="a", name=name, 
label=i18n(name)), + description=i18n(name), + parameters=parameters or [], + ) + + +def make_provider_entity( + name: str = "test_provider", + events: list[EventEntity] | None = None, + constructor: SubscriptionConstructor | None = None, + subscription_schema: list[ProviderConfig] | None = None, + icon: str | None = "icon.png", + icon_dark: str | None = None, +) -> TriggerProviderEntity: + return TriggerProviderEntity( + identity=TriggerProviderIdentity( + author="a", + name=name, + label=i18n(name), + description=i18n(name), + icon=icon, + icon_dark=icon_dark, + ), + events=events if events is not None else [make_event()], + subscription_constructor=constructor, + subscription_schema=subscription_schema or [], + ) + + +def make_controller( + entity: TriggerProviderEntity | None = None, + tenant_id: str = "tenant-1", + provider_id: str = VALID_PROVIDER_ID, +) -> PluginTriggerProviderController: + return PluginTriggerProviderController( + entity=entity or make_provider_entity(), + plugin_id="plugin-1", + plugin_unique_identifier="uid-1", + provider_id=TriggerProviderID(provider_id), + tenant_id=tenant_id, + ) + + +def make_subscription(**overrides: Any) -> Subscription: + defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}} + defaults.update(overrides) + return Subscription(**defaults) + + +def make_provider_config( + name: str = "api_key", required: bool = True, config_type: str = "secret-input" +) -> ProviderConfig: + return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required) + + +def make_constructor( + credentials_schema: list[ProviderConfig] | None = None, + oauth_schema: OAuthSchema | None = None, +) -> SubscriptionConstructor: + return SubscriptionConstructor( + parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema + ) diff --git a/api/tests/unit_tests/core/trigger/debug/__init__.py b/api/tests/unit_tests/core/trigger/debug/__init__.py 
new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py new file mode 100644 index 0000000000..d557c20f5e --- /dev/null +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py @@ -0,0 +1,93 @@ +""" +Tests for core.trigger.debug.event_bus.TriggerDebugEventBus. + +Covers: Lua-script dispatch/poll with Redis error resilience. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from redis import RedisError + +from core.trigger.debug.event_bus import TriggerDebugEventBus +from core.trigger.debug.events import PluginTriggerDebugEvent + + +class TestDispatch: + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_dispatch_count(self, mock_redis): + mock_redis.eval.return_value = 3 + event = MagicMock() + event.model_dump_json.return_value = '{"test": true}' + + result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key") + + assert result == 3 + mock_redis.eval.assert_called_once() + + @patch("core.trigger.debug.event_bus.redis_client") + def test_redis_error_returns_zero(self, mock_redis): + mock_redis.eval.side_effect = RedisError("connection lost") + event = MagicMock() + event.model_dump_json.return_value = "{}" + + result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key") + + assert result == 0 + + +class TestPoll: + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_deserialized_event(self, mock_redis): + event_json = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ).model_dump_json() + mock_redis.eval.return_value = event_json + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is not None + assert 
result.name == "push" + + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_none_when_no_event(self, mock_redis): + mock_redis.eval.return_value = None + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is None + + @patch("core.trigger.debug.event_bus.redis_client") + def test_redis_error_returns_none(self, mock_redis): + mock_redis.eval.side_effect = RedisError("timeout") + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is None diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py new file mode 100644 index 0000000000..331bcd6c25 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -0,0 +1,276 @@ +""" +Tests for core.trigger.debug.event_selectors. + +Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory, +and select_trigger_debug_events orchestrator. 
+""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.debug.event_selectors import ( + PluginTriggerDebugEventPoller, + ScheduleTriggerDebugEventPoller, + WebhookTriggerDebugEventPoller, + create_event_poller, + select_trigger_debug_events, +) +from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent +from dify_graph.enums import NodeType +from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID + + +def _make_poller_args(node_config: dict | None = None) -> dict: + return { + "tenant_id": "t1", + "user_id": "u1", + "app_id": "a1", + "node_config": node_config or {"data": {}}, + "node_id": "n1", + } + + +def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict: + """Valid node config for TriggerEventNodeData.model_validate.""" + return { + "data": { + "title": "test", + "plugin_id": "org/testplugin", + "provider_id": provider_id, + "event_name": "push", + "subscription_id": "s1", + "plugin_unique_identifier": "uid-1", + } + } + + +class TestPluginTriggerDebugEventPoller: + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_workflow_args_on_success(self, mock_bus): + event = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc: + mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse( + variables={"repo": "dify"}, + cancelled=False, + ) + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"repo": "dify"} + + 
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_no_event(self, mock_bus): + mock_bus.poll.return_value = None + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + + assert poller.poll() is None + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_invoke_cancelled(self, mock_bus): + event = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc: + mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse( + variables={}, + cancelled=True, + ) + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + + assert poller.poll() is None + + +class TestWebhookTriggerDebugEventPoller: + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_uses_inputs_directly_when_present(self, mock_bus): + event = WebhookDebugEvent( + timestamp=100, + request_id="r1", + node_id="n1", + payload={"inputs": {"key": "val"}, "webhook_data": {}}, + ) + mock_bus.poll.return_value = event + + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"key": "val"} + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_falls_back_to_webhook_data(self, mock_bus): + event = WebhookDebugEvent( + timestamp=100, + request_id="r1", + node_id="n1", + payload={"webhook_data": {"body": "raw"}}, + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc: + mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"} + + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + result = 
poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"parsed": "data"} + mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"}) + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_no_event(self, mock_bus): + mock_bus.poll.return_value = None + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + + assert poller.poll() is None + + +class TestScheduleTriggerDebugEventPoller: + def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime): + """Set up mocks and create a schedule poller.""" + mock_redis.get.return_value = None + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + return ScheduleTriggerDebugEventPoller(**_make_poller_args()) + + @patch("core.trigger.debug.event_selectors.redis_client") + @patch("core.trigger.debug.event_selectors.naive_utc_now") + @patch("core.trigger.debug.event_selectors.calculate_next_run_at") + @patch("core.trigger.debug.event_selectors.ensure_naive_utc") + def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis): + now = datetime(2025, 1, 1, 12, 0, 0) + next_run = datetime(2025, 1, 1, 13, 0, 0) # future + mock_now.return_value = now + mock_calc.return_value = next_run + mock_ensure.return_value = next_run + mock_redis.get.return_value = None + + with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc: + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + + poller = ScheduleTriggerDebugEventPoller(**_make_poller_args()) + + assert poller.poll() is None + + @patch("core.trigger.debug.event_selectors.redis_client") + 
@patch("core.trigger.debug.event_selectors.naive_utc_now") + @patch("core.trigger.debug.event_selectors.calculate_next_run_at") + @patch("core.trigger.debug.event_selectors.ensure_naive_utc") + def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis): + now = datetime(2025, 1, 1, 14, 0, 0) + next_run = datetime(2025, 1, 1, 12, 0, 0) # past + mock_now.return_value = now + mock_calc.return_value = next_run + mock_ensure.return_value = next_run + mock_redis.get.return_value = None + + with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc: + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + + poller = ScheduleTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + mock_redis.delete.assert_called_once() + + +class TestCreateEventPoller: + def _workflow_with_node(self, node_type: NodeType): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = node_type + return wf + + def test_creates_plugin_poller(self): + wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, PluginTriggerDebugEventPoller) + + def test_creates_webhook_poller(self): + wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, WebhookTriggerDebugEventPoller) + + def test_creates_schedule_poller(self): + wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, ScheduleTriggerDebugEventPoller) + + def test_raises_for_unknown_type(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + 
wf.get_node_type_from_node_config.return_value = NodeType.START + + with pytest.raises(ValueError): + create_event_poller(wf, "t1", "u1", "a1", "n1") + + def test_raises_when_node_config_missing(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = None + + with pytest.raises(ValueError): + create_event_poller(wf, "t1", "u1", "a1", "n1") + + +class TestSelectTriggerDebugEvents: + def test_returns_first_non_none_event(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK + app_model = MagicMock() + app_model.tenant_id = "t1" + app_model.id = "a1" + + with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll: + expected = MagicMock() + mock_poll.return_value = expected + + result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"]) + + assert result is expected + + def test_returns_none_when_no_events(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK + app_model = MagicMock() + app_model.tenant_id = "t1" + app_model.id = "a1" + + with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None): + result = select_trigger_debug_events(wf, app_model, "u1", ["n1"]) + + assert result is None diff --git a/api/tests/unit_tests/core/trigger/test_provider.py b/api/tests/unit_tests/core/trigger/test_provider.py new file mode 100644 index 0000000000..3c2f297e90 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/test_provider.py @@ -0,0 +1,332 @@ +""" +Tests for core.trigger.provider.PluginTriggerProviderController. + +Covers: to_api_entity creation-method logic, credential validation pipeline, +schema resolution by type, event lookup, dispatch/invoke/subscribe delegation. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.entities.entities import ( + EventParameter, + EventParameterType, + OAuthSchema, + TriggerCreationMethod, +) +from core.trigger.errors import TriggerProviderCredentialValidationError +from tests.unit_tests.core.trigger.conftest import ( + i18n, + make_constructor, + make_controller, + make_event, + make_provider_config, + make_provider_entity, + make_subscription, +) + +ICON_URL = "https://cdn/icon.png" + + +class TestToApiEntity: + @patch("core.trigger.provider.PluginService") + def test_includes_icons_when_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png")) + + api = ctrl.to_api_entity() + + assert api.icon == ICON_URL + assert api.icon_dark == ICON_URL + + @patch("core.trigger.provider.PluginService") + def test_icons_none_when_absent(self, mock_plugin_svc): + ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None)) + + api = ctrl.to_api_entity() + + assert api.icon is None + assert api.icon_dark is None + mock_plugin_svc.get_plugin_icon_url.assert_not_called() + + @patch("core.trigger.provider.PluginService") + def test_manual_only_without_schemas(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + + api = ctrl.to_api_entity() + + assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL] + + @patch("core.trigger.provider.PluginService") + def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + oauth = OAuthSchema(client_schema=[], credentials_schema=[]) + ctrl = 
make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + api = ctrl.to_api_entity() + + assert TriggerCreationMethod.OAUTH in api.supported_creation_methods + assert TriggerCreationMethod.MANUAL in api.supported_creation_methods + + @patch("core.trigger.provider.PluginService") + def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + + api = ctrl.to_api_entity() + + assert TriggerCreationMethod.APIKEY in api.supported_creation_methods + + +class TestGetEvent: + def test_returns_matching_event(self): + evt = make_event("push") + ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")])) + + assert ctrl.get_event("push") is evt + + def test_returns_none_for_unknown(self): + ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")])) + + assert ctrl.get_event("nonexistent") is None + + +class TestGetSubscriptionDefaultProperties: + def test_returns_defaults_skipping_none(self): + config1 = make_provider_config("key1") + config1.default = "val1" + config2 = make_provider_config("key2") + config2.default = None + ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2])) + + props = ctrl.get_subscription_default_properties() + + assert props == {"key1": "val1"} + + +class TestValidateCredentials: + def test_raises_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + + with pytest.raises(ValueError, match="Subscription constructor not found"): + ctrl.validate_credentials("u1", {"key": "val"}) + + def test_raises_for_missing_required_field(self): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + 
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + + with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"): + ctrl.validate_credentials("u1", {}) + + @patch("core.trigger.provider.PluginTriggerClient") + def test_passes_with_valid_credentials(self, mock_client): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + mock_client.return_value.validate_provider_credentials.return_value = True + + ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise + + @patch("core.trigger.provider.PluginTriggerClient") + def test_raises_when_plugin_rejects(self, mock_client): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + mock_client.return_value.validate_provider_credentials.return_value = None + + with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"): + ctrl.validate_credentials("u1", {"api_key": "bad"}) + + +class TestGetSupportedCredentialTypes: + def test_empty_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + assert ctrl.get_supported_credential_types() == [] + + def test_oauth_only(self): + oauth = OAuthSchema(client_schema=[], credentials_schema=[]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.OAUTH2 in types + assert CredentialType.API_KEY not in types + + def test_apikey_only(self): + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + + types = ctrl.get_supported_credential_types() + + assert 
CredentialType.API_KEY in types + assert CredentialType.OAUTH2 not in types + + def test_both(self): + oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")]) + ctrl = make_controller( + entity=make_provider_entity( + constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth) + ) + ) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.OAUTH2 in types + assert CredentialType.API_KEY in types + + +class TestGetCredentialsSchema: + def test_returns_empty_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + assert ctrl.get_credentials_schema(CredentialType.API_KEY) == [] + + def test_returns_apikey_credentials(self): + cfg = make_provider_config("token") + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg]))) + + result = ctrl.get_credentials_schema(CredentialType.API_KEY) + + assert len(result) == 1 + assert result[0].name == "token" + + def test_returns_oauth_credentials(self): + oauth_cred = make_provider_config("oauth_token") + oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + result = ctrl.get_credentials_schema(CredentialType.OAUTH2) + + assert len(result) == 1 + assert result[0].name == "oauth_token" + + def test_unauthorized_returns_empty(self): + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == [] + + def test_invalid_type_raises(self): + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor())) + with pytest.raises(ValueError, match="Invalid credential type"): + ctrl.get_credentials_schema("bogus_type") + + +class TestGetEventParameters: + def 
test_returns_params_for_known_event(self): + param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING) + evt = make_event("push", parameters=[param]) + ctrl = make_controller(entity=make_provider_entity(events=[evt])) + + result = ctrl.get_event_parameters("push") + + assert "branch" in result + assert result["branch"].name == "branch" + + def test_returns_empty_for_unknown_event(self): + ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")])) + + assert ctrl.get_event_parameters("nonexistent") == {} + + +class TestDispatch: + @patch("core.trigger.provider.PluginTriggerClient") + def test_delegates_to_client(self, mock_client): + ctrl = make_controller() + expected = MagicMock() + mock_client.return_value.dispatch_event.return_value = expected + + result = ctrl.dispatch( + request=MagicMock(), + subscription=make_subscription(), + credentials={"k": "v"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + mock_client.return_value.dispatch_event.assert_called_once() + + +class TestInvokeTriggerEvent: + @patch("core.trigger.provider.PluginTriggerClient") + def test_delegates_to_client(self, mock_client): + ctrl = make_controller() + expected = MagicMock() + mock_client.return_value.invoke_trigger_event.return_value = expected + + result = ctrl.invoke_trigger_event( + user_id="u1", + event_name="push", + parameters={}, + credentials={}, + credential_type=CredentialType.API_KEY, + subscription=make_subscription(), + request=MagicMock(), + payload={}, + ) + + assert result is expected + + +class TestSubscribeTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_returns_validated_subscription(self, mock_client): + ctrl = make_controller() + mock_client.return_value.subscribe.return_value.subscription = { + "expires_at": 123, + "endpoint": "https://e", + "properties": {}, + } + + result = ctrl.subscribe_trigger( + user_id="u1", + endpoint="https://e", + parameters={}, 
+ credentials={}, + credential_type=CredentialType.API_KEY, + ) + + assert result.endpoint == "https://e" + + +class TestUnsubscribeTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_returns_validated_result(self, mock_client): + ctrl = make_controller() + mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"} + + result = ctrl.unsubscribe_trigger( + user_id="u1", + subscription=make_subscription(), + credentials={}, + credential_type=CredentialType.API_KEY, + ) + + assert result.success is True + + +class TestRefreshTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_uses_system_user_id(self, mock_client): + ctrl = make_controller() + mock_client.return_value.refresh.return_value.subscription = { + "expires_at": 456, + "endpoint": "https://e", + "properties": {}, + } + + ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY) + + call_kwargs = mock_client.return_value.refresh.call_args[1] + assert call_kwargs["user_id"] == "system" diff --git a/api/tests/unit_tests/core/trigger/test_trigger_manager.py b/api/tests/unit_tests/core/trigger/test_trigger_manager.py new file mode 100644 index 0000000000..612be25ec9 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/test_trigger_manager.py @@ -0,0 +1,307 @@ +""" +Tests for core.trigger.trigger_manager.TriggerManager. + +Covers: icon URL construction, provider listing with error resilience, +double-check lock caching, error translation, EventIgnoreError -> cancelled, +and delegation to provider controller. 
+""" + +from __future__ import annotations + +from threading import Lock +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError +from core.trigger.errors import EventIgnoreError +from core.trigger.trigger_manager import TriggerManager +from models.provider_ids import TriggerProviderID +from tests.unit_tests.core.trigger.conftest import ( + VALID_PROVIDER_ID, + make_controller, + make_provider_entity, + make_subscription, +) + +PID = TriggerProviderID(VALID_PROVIDER_ID) +PID_STR = str(PID) + + +class TestGetTriggerPluginIcon: + @patch("core.trigger.trigger_manager.dify_config") + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_builds_correct_url(self, mock_client, mock_config): + mock_config.CONSOLE_API_URL = "https://console.example.com" + provider = MagicMock() + provider.declaration.identity.icon = "my-icon.svg" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID) + + assert "tenant_id=tenant-1" in url + assert "filename=my-icon.svg" in url + assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon") + + +class TestListPluginTriggerProviders: + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_wraps_entities_into_controllers(self, mock_client): + entity = MagicMock() + entity.declaration = make_provider_entity("p1") + entity.plugin_id = "plugin-1" + entity.plugin_unique_identifier = "uid-1" + entity.provider = VALID_PROVIDER_ID + mock_client.return_value.fetch_trigger_providers.return_value = [entity] + + controllers = TriggerManager.list_plugin_trigger_providers("tenant-1") + + assert len(controllers) == 1 + assert controllers[0].plugin_id == "plugin-1" + + 
@patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_skips_failing_providers(self, mock_client): + good = MagicMock() + good.declaration = make_provider_entity("good") + good.plugin_id = "good-plugin" + good.plugin_unique_identifier = "uid-good" + good.provider = VALID_PROVIDER_ID + + bad = MagicMock() + bad.declaration = make_provider_entity("bad") + bad.plugin_id = "bad-plugin" + bad.plugin_unique_identifier = "uid-bad" + bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation + + mock_client.return_value.fetch_trigger_providers.return_value = [bad, good] + + controllers = TriggerManager.list_plugin_trigger_providers("tenant-1") + + assert len(controllers) == 1 + assert controllers[0].plugin_id == "good-plugin" + + +class TestGetTriggerProvider: + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_initializes_context_on_first_call(self, mock_ctx, mock_client): + # get() called 3 times: (1) try block, (2) after set, (3) under lock + mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}] + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + provider = MagicMock() + provider.declaration = make_provider_entity() + provider.plugin_id = "p1" + provider.plugin_unique_identifier = "uid-1" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + result = TriggerManager.get_trigger_provider("t1", PID) + + mock_ctx.plugin_trigger_providers.set.assert_called_once_with({}) + mock_ctx.plugin_trigger_providers_lock.set.assert_called_once() + assert result is not None + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_returns_cached_without_fetch(self, mock_ctx, mock_client): + cached = make_controller() + mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached} + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result 
is cached + mock_client.return_value.fetch_trigger_provider.assert_not_called() + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client): + cached = make_controller() + mock_ctx.plugin_trigger_providers.get.side_effect = [ + {}, # first check misses + {PID_STR: cached}, # under-lock check hits + ] + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is cached + mock_client.return_value.fetch_trigger_provider.assert_not_called() + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client): + cache: dict = {} + mock_ctx.plugin_trigger_providers.get.return_value = cache + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + provider = MagicMock() + provider.declaration = make_provider_entity() + provider.plugin_id = "p1" + provider.plugin_unique_identifier = "uid-1" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is not None + assert PID_STR in cache + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_none_fetch_raises_value_error(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.return_value = None + + with pytest.raises(ValueError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing")) + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_plugin_not_found_becomes_value_error(self, 
mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone") + + with pytest.raises(ValueError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss")) + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error") + + with pytest.raises(PluginDaemonError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss")) + + +class TestListAllTriggerProviders: + @patch.object(TriggerManager, "list_plugin_trigger_providers") + def test_delegates_to_list_plugin(self, mock_list): + expected = [make_controller()] + mock_list.return_value = expected + + assert TriggerManager.list_all_trigger_providers("t1") is expected + mock_list.assert_called_once_with("t1") + + +class TestListTriggersByProvider: + @patch.object(TriggerManager, "get_trigger_provider") + def test_returns_provider_events(self, mock_get): + ctrl = make_controller() + mock_get.return_value = ctrl + + result = TriggerManager.list_triggers_by_provider("t1", PID) + + assert result == ctrl.get_events() + + +class TestInvokeTriggerEvent: + def _args(self): + return { + "tenant_id": "t1", + "user_id": "u1", + "provider_id": PID, + "event_name": "on_push", + "parameters": {"branch": "main"}, + "credentials": {"token": "abc"}, + "credential_type": CredentialType.API_KEY, + "subscription": make_subscription(), + "request": MagicMock(), + "payload": {"action": "push"}, + } + + @patch.object(TriggerManager, "get_trigger_provider") + def 
test_returns_invoke_response(self, mock_get): + ctrl = MagicMock() + expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False) + ctrl.invoke_trigger_event.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.invoke_trigger_event(**self._args()) + + assert result is expected + assert result.cancelled is False + + @patch.object(TriggerManager, "get_trigger_provider") + def test_event_ignore_returns_cancelled(self, mock_get): + ctrl = MagicMock() + ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip") + mock_get.return_value = ctrl + + result = TriggerManager.invoke_trigger_event(**self._args()) + + assert result.cancelled is True + assert result.variables == {} + + @patch.object(TriggerManager, "get_trigger_provider") + def test_other_errors_propagate(self, mock_get): + ctrl = MagicMock() + ctrl.invoke_trigger_event.side_effect = RuntimeError("boom") + mock_get.return_value = ctrl + + with pytest.raises(RuntimeError, match="boom"): + TriggerManager.invoke_trigger_event(**self._args()) + + +class TestSubscribeTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = make_subscription() + ctrl.subscribe_trigger.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.subscribe_trigger( + tenant_id="t1", + user_id="u1", + provider_id=PID, + endpoint="https://hook.test", + parameters={"f": "all"}, + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + ctrl.subscribe_trigger.assert_called_once() + + +class TestUnsubscribeTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = MagicMock() + ctrl.unsubscribe_trigger.return_value = expected + mock_get.return_value = ctrl + sub = make_subscription() + + result = TriggerManager.unsubscribe_trigger( 
+ tenant_id="t1", + user_id="u1", + provider_id=PID, + subscription=sub, + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + + +class TestRefreshTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = make_subscription() + ctrl.refresh_trigger.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.refresh_trigger( + tenant_id="t1", + provider_id=PID, + subscription=make_subscription(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected diff --git a/api/tests/unit_tests/core/trigger/utils/__init__.py b/api/tests/unit_tests/core/trigger/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py new file mode 100644 index 0000000000..8804526e2e --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py @@ -0,0 +1,62 @@ +"""Tests for core.trigger.utils.encryption — masking logic and cache key generation.""" + +from __future__ import annotations + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.utils.encryption import ( + TriggerProviderCredentialsCache, + TriggerProviderOAuthClientParamsCache, + TriggerProviderPropertiesCache, + masked_credentials, +) + + +def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig: + return ProviderConfig( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + type=field_type, + ) + + +class TestMaskedCredentials: + def test_short_secret_fully_masked(self): + schema = [_make_schema("key", "secret-input")] + result = masked_credentials(schema, {"key": "ab"}) + assert result["key"] == "**" + + def 
test_long_secret_partially_masked(self): + schema = [_make_schema("key", "secret-input")] + result = masked_credentials(schema, {"key": "abcdef"}) + assert result["key"].startswith("ab") + assert result["key"].endswith("ef") + assert "**" in result["key"] + + def test_non_secret_field_unchanged(self): + schema = [_make_schema("host", "text-input")] + result = masked_credentials(schema, {"host": "example.com"}) + assert result["host"] == "example.com" + + def test_unknown_key_passes_through(self): + result = masked_credentials([], {"unknown": "value"}) + assert result["unknown"] == "value" + + +class TestCacheKeyGeneration: + def test_credentials_cache_key_contains_ids(self): + cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + assert "c1" in cache.cache_key + + def test_oauth_client_cache_key_contains_ids(self): + cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + + def test_properties_cache_key_contains_ids(self): + cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + assert "s1" in cache.cache_key diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py new file mode 100644 index 0000000000..e5879aea0a --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py @@ -0,0 +1,31 @@ +"""Tests for core.trigger.utils.endpoint — URL generation.""" + +from __future__ import annotations + +from unittest.mock import patch + +from yarl import URL + +from core.trigger.utils import endpoint + + +class TestGeneratePluginTriggerEndpointUrl: + def test_builds_correct_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = 
endpoint.generate_plugin_trigger_endpoint_url("endpoint-123") + + assert url == "https://api.example.com/triggers/plugin/endpoint-123" + + +class TestGenerateWebhookTriggerEndpoint: + def test_non_debug_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False) + + assert url == "https://api.example.com/triggers/webhook/sub-456" + + def test_debug_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True) + + assert url == "https://api.example.com/triggers/webhook-debug/sub-456" diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py new file mode 100644 index 0000000000..4fa202b164 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py @@ -0,0 +1,23 @@ +"""Tests for core.trigger.utils.locks — Redis lock key builders.""" + +from __future__ import annotations + +from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys + + +class TestBuildTriggerRefreshLockKey: + def test_correct_format(self): + key = build_trigger_refresh_lock_key("tenant-1", "sub-1") + + assert key == "trigger_provider_refresh_lock:tenant-1_sub-1" + + +class TestBuildTriggerRefreshLockKeys: + def test_maps_over_pairs(self): + pairs = [("t1", "s1"), ("t2", "s2")] + + keys = build_trigger_refresh_lock_keys(pairs) + + assert len(keys) == 2 + assert keys[0] == "trigger_provider_refresh_lock:t1_s1" + assert keys[1] == "trigger_provider_refresh_lock:t2_s2" diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 0df4927697..22792eb5b3 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ 
b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch import pytest +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.variables.variables import StringVariable class StubCoordinator: @@ -278,3 +280,17 @@ class TestGraphRuntimeState: assert restored_execution.started is True assert new_stub.state == "configured" + + def test_snapshot_restore_preserves_updated_conversation_variable(self): + variable_pool = VariablePool( + conversation_variables=[StringVariable(name="session_name", value="before")], + ) + variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + snapshot = state.dumps() + restored = GraphRuntimeState.from_snapshot(snapshot) + + restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) + assert restored_value is not None + assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index eb449e6d75..8c8e5977c8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,15 +2,7 @@ Simple test to verify MockNodeFactory works with iteration nodes. 
""" -import sys -from pathlib import Path - from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY - -# Add api directory to path -api_dir = Path(__file__).parent.parent.parent.parent.parent.parent -sys.path.insert(0, str(api_dir)) - from dify_graph.enums import NodeType from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 3f458e9de9..34e714a227 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode from dify_graph.nodes.llm import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor import ParameterExtractorNode +from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol from dify_graph.nodes.question_classifier import QuestionClassifierNode from dify_graph.nodes.template_transform import TemplateTransformNode from dify_graph.nodes.template_transform.template_renderer import ( @@ -65,11 +66,19 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + # LLM-like nodes now require an http_client; provide a mock by default for tests. 
+ kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + # Provide default tool_file_manager_factory for ToolNode subclasses + from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + + if isinstance(self, _ToolNode): + kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + super().__init__( id=id, config=config, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 84d1444585..693cdf9276 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -3,14 +3,8 @@ Simple test to validate the auto-mock system without external dependencies. 
""" import sys -from pathlib import Path from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY - -# Add api directory to path -api_dir = Path(__file__).parent.parent.parent.parent.parent.parent -sys.path.insert(0, str(api_dir)) - from dify_graph.enums import NodeType from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index e194d66ee3..e929d652fd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -8,7 +8,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.nodes.knowledge_retrieval.entities import ( + Condition, KnowledgeRetrievalNodeData, + MetadataFilteringCondition, MultipleRetrievalConfig, RerankingModelConfig, SingleRetrievalConfig, @@ -110,7 +112,6 @@ class TestKnowledgeRetrievalNode: # Assert assert node.id == node_id assert node._rag_retrieval == mock_rag_retrieval - assert node._llm_file_saver is not None def test_run_with_no_query_or_attachment( self, @@ -205,6 +206,7 @@ class TestKnowledgeRetrievalNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert "result" in result.outputs assert mock_rag_retrieval.knowledge_retrieval.called + mock_source.model_dump.assert_called_once_with(by_alias=True) def test_run_with_query_variable_multiple_mode( self, @@ -592,3 +594,106 @@ class TestFetchDatasetRetriever: # Assert assert version == "1" + + def 
test_resolve_metadata_filtering_conditions_templates( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """_resolve_metadata_filtering_conditions should expand {{#...#}} and keep numbers/None unchanged.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": { + "title": "Knowledge Retrieval", + "type": "knowledge-retrieval", + "dataset_ids": [str(uuid.uuid4())], + "retrieval_mode": "multiple", + }, + } + # Variable in pool used by template + mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + conditions = MetadataFilteringCondition( + logical_operator="and", + conditions=[ + Condition(name="document_name", comparison_operator="is", value="{{#start.query#}}"), + Condition(name="tags", comparison_operator="in", value=["x", "{{#start.query#}}"]), + Condition(name="year", comparison_operator="=", value=2025), + ], + ) + + # Act + resolved = node._resolve_metadata_filtering_conditions(conditions) + + # Assert + assert resolved.logical_operator == "and" + assert resolved.conditions[0].value == "readme" + assert isinstance(resolved.conditions[1].value, list) + assert resolved.conditions[1].value[1] == "readme" + assert resolved.conditions[2].value == 2025 + + def test_fetch_passes_resolved_metadata_conditions( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """_fetch_dataset_retriever should pass resolved metadata conditions into request.""" + # Arrange + query = "hi" + variables = {"query": query} + mock_graph_runtime_state.variable_pool.add(["start", "q"], StringSegment(value="readme")) + + node_data = KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + 
dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + multiple_retrieval_config=MultipleRetrievalConfig( + top_k=4, + score_threshold=0.0, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RerankingModelConfig(provider="cohere", model="rerank-v2"), + ), + metadata_filtering_mode="manual", + metadata_filtering_conditions=MetadataFilteringCondition( + logical_operator="and", + conditions=[ + Condition(name="document_name", comparison_operator="is", value="{{#start.q#}}"), + ], + ), + ) + + node_id = str(uuid.uuid4()) + config = {"id": node_id, "data": node_data.model_dump()} + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + mock_rag_retrieval.knowledge_retrieval.return_value = [] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + # Act + node._fetch_dataset_retriever(node_data=node_data, variables=variables) + + # Assert the passed request has resolved value + call_args = mock_rag_retrieval.knowledge_retrieval.call_args + request = call_args[1]["request"] + assert request.metadata_filtering_conditions is not None + assert request.metadata_filtering_conditions.conditions[0].value == "readme" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index a3afd1ed5c..b0f0fd428b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -1,10 +1,10 @@ import uuid from typing import NamedTuple from unittest import mock +from unittest.mock import MagicMock import httpx import pytest -from sqlalchemy import Engine from core.helper import ssrf_proxy from core.tools import signature @@ -44,7 +44,6 @@ class TestFileSaverImpl: ) mock_tool_file.id = _gen_id() mocked_tool_file_manager = 
mock.MagicMock(spec=ToolFileManager) - mocked_engine = mock.MagicMock(spec=Engine) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) @@ -53,11 +52,12 @@ class TestFileSaverImpl: # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) mocked_sign_file.return_value = mock_signed_url + http_client = MagicMock() storage_file_manager = FileSaverImpl( user_id=user_id, tenant_id=tenant_id, - engine_factory=mocked_engine, + http_client=http_client, ) file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) @@ -87,16 +87,18 @@ class TestFileSaverImpl: status_code=401, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response + file_saver = FileSaverImpl( user_id=_gen_id(), tenant_id=_gen_id(), + http_client=http_client, ) - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) with pytest.raises(httpx.HTTPStatusError) as exc: file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_get.assert_called_once_with(_TEST_URL) + http_client.get.assert_called_once_with(_TEST_URL) assert exc.value.response.status_code == 401 def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): @@ -112,8 +114,10 @@ class TestFileSaverImpl: headers={"Content-Type": mime_type}, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) + file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) mock_tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py 
b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 90308facc3..d56035b6bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -111,6 +111,7 @@ def llm_node( "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -120,6 +121,7 @@ def llm_node( model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + http_client=http_client, ) return node @@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + http_client=http_client, ) return node, mock_file_saver diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 5e20b1e12f..13275d4be6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -16,6 +16,7 @@ from dify_graph.nodes.document_extractor.node import ( _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, + _normalize_docx_zip, ) from dify_graph.variables import ArrayFileSegment from dify_graph.variables.segments import ArrayStringSegment @@ -86,6 +87,38 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s assert "is not an ArrayFileSegment" in result.error +def test_run_empty_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state): + """Empty file list 
should return SUCCEEDED with empty documents and ArrayStringSegment([]).""" + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + # Provide an actual ArrayFileSegment with an empty list + mock_graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=[]) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.process_data.get("documents") == [] + assert result.outputs["text"] == ArrayStringSegment(value=[]) + + +def test_run_none_only_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state): + """A file list containing only None (e.g., [None]) should be filtered to [] and succeed.""" + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + # Use a Mock to bypass type validation for None entries in the list + afs = Mock(spec=ArrayFileSegment) + afs.value = [None] + mock_graph_runtime_state.variable_pool.get.return_value = afs + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.process_data.get("documents") == [] + assert result.outputs["text"] == ArrayStringSegment(value=[]) + + @pytest.mark.parametrize( ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ @@ -385,3 +418,58 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n" assert expected_manual == result + + +def _make_docx_zip(use_backslash: bool) -> bytes: + """Helper to build a minimal in-memory DOCX zip. + + When use_backslash=True the ZIP entry names use backslash separators + (as produced by Evernote on Windows), otherwise forward slashes are used. 
+ """ + import zipfile + + sep = "\\" if use_backslash else "/" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr("[Content_Types].xml", b"") + zf.writestr(f"_rels{sep}.rels", b"") + zf.writestr(f"word{sep}document.xml", b"") + zf.writestr(f"word{sep}_rels{sep}document.xml.rels", b"") + return buf.getvalue() + + +def test_normalize_docx_zip_replaces_backslashes(): + """ZIP entries with backslash separators must be rewritten to forward slashes.""" + import zipfile + + malformed = _make_docx_zip(use_backslash=True) + fixed = _normalize_docx_zip(malformed) + + with zipfile.ZipFile(io.BytesIO(fixed)) as zf: + names = zf.namelist() + + assert "word/document.xml" in names + assert "word/_rels/document.xml.rels" in names + # No entry should contain a backslash after normalization + assert all("\\" not in name for name in names) + + +def test_normalize_docx_zip_leaves_forward_slash_unchanged(): + """ZIP entries that already use forward slashes must not be modified.""" + import zipfile + + normal = _make_docx_zip(use_backslash=False) + fixed = _normalize_docx_zip(normal) + + with zipfile.ZipFile(io.BytesIO(fixed)) as zf: + names = zf.namelist() + + assert "word/document.xml" in names + assert "word/_rels/document.xml.rels" in names + + +def test_normalize_docx_zip_returns_original_on_bad_zip(): + """Non-zip bytes must be returned as-is without raising.""" + garbage = b"not a zip file at all" + result = _normalize_docx_zip(garbage) + assert result == garbage diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 11554169e1..3cbd96dfef 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -31,6 +31,7 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute 
monkeypatch.setitem(sys.modules, module_name, ops_stub) + from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { @@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] + + # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor + tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + node = ToolNode( id="node-instance", config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + tool_file_manager_factory=tool_file_manager_factory, ) return node diff --git a/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py b/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py new file mode 100644 index 0000000000..7a537b0502 --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py @@ -0,0 +1,172 @@ +"""Tests for Celery SQL comment context injection.""" + +from unittest.mock import MagicMock, patch + +from opentelemetry import context + + +class TestBuildCelerySqlcommenterTags: + """Tests for _build_celery_sqlcommenter_tags.""" + + def test_includes_framework_and_task_name(self): + """Tags include celery framework version and task name.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.async_workflow_tasks.execute_workflow_team" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "framework" in tags + assert tags["framework"].startswith("celery:") + assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team" + + def test_includes_celery_retries_when_nonzero(self): + 
"""celery_retries is included when retries > 0.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 3 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["celery_retries"] == 3 + + def test_omits_celery_retries_when_zero(self): + """celery_retries is omitted when retries is 0.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "celery_retries" not in tags + + def test_includes_routing_key_from_delivery_info(self): + """routing_key is included when present in delivery_info.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {"routing_key": "workflow_based_app_execution"} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["routing_key"] == "workflow_based_app_execution" + + def test_includes_traceparent_when_available(self): + """traceparent is included when injectable from current context.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + traceparent = "00-5db86c23fa8d05b67db315694b518684-737bbf30cdcda066-00" + with patch( + 
"extensions.otel.celery_sqlcommenter._get_traceparent", + return_value=traceparent, + ): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["traceparent"] == traceparent + + def test_handles_task_without_request(self): + """Gracefully handles task without request attribute.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + del task.request + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "framework" in tags + assert "task_name" in tags + + +class TestTaskPrerunPostrunHandlers: + """Tests for task_prerun and task_postrun signal handlers.""" + + def test_prerun_sets_context_postrun_detaches(self): + """task_prerun attaches SQLCOMMENTER context; task_postrun detaches it.""" + from extensions.otel.celery_sqlcommenter import ( + _SQLCOMMENTER_CONTEXT_KEY, + _TOKEN_ATTR, + _on_task_postrun, + _on_task_prerun, + ) + + clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None) + token = context.attach(clean_ctx) + try: + task = MagicMock() + task.name = "tasks.async_workflow_tasks.execute_workflow_team" + task.request = MagicMock() + task.request.retries = 1 + task.request.delivery_info = {"routing_key": "workflow_based_app_execution"} + + with patch( + "extensions.otel.celery_sqlcommenter._get_traceparent", + return_value="00-abc123-def456-00", + ): + _on_task_prerun(task=task) + + tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags is not None + assert tags["framework"].startswith("celery:") + assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team" + assert tags["celery_retries"] == 1 + assert tags["routing_key"] == "workflow_based_app_execution" + assert tags["traceparent"] == "00-abc123-def456-00" + assert hasattr(task, _TOKEN_ATTR) + + _on_task_postrun(task=task) + + tags_after = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + 
assert tags_after is None + assert not hasattr(task, _TOKEN_ATTR) + finally: + context.detach(token) + + def test_prerun_skips_when_no_task(self): + """prerun does nothing when task is missing from kwargs.""" + from extensions.otel.celery_sqlcommenter import ( + _SQLCOMMENTER_CONTEXT_KEY, + _on_task_prerun, + ) + + clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None) + token = context.attach(clean_ctx) + try: + _on_task_prerun() + tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags is None + finally: + context.detach(token) + + def test_postrun_skips_when_no_token(self): + """postrun does nothing when task has no token (e.g. prerun was skipped).""" + from extensions.otel.celery_sqlcommenter import _on_task_postrun + + task = MagicMock() + _on_task_postrun(task=task) diff --git a/api/tests/unit_tests/services/dataset_permission_service.py b/api/tests/unit_tests/services/dataset_permission_service.py index b687f472a5..e098e90455 100644 --- a/api/tests/unit_tests/services/dataset_permission_service.py +++ b/api/tests/unit_tests/services/dataset_permission_service.py @@ -258,323 +258,6 @@ class DatasetPermissionTestDataFactory: return [{"user_id": user_id} for user_id in user_ids] -# ============================================================================ -# Tests for get_dataset_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceGetPartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.get_dataset_partial_member_list method. - - This test class covers the retrieval of partial member lists for datasets, - which returns a list of account IDs that have explicit permissions for - a given dataset. - - The get_dataset_partial_member_list method: - 1. Queries DatasetPermission table for the dataset ID - 2. Selects account_id values - 3. 
Returns list of account IDs - - Test scenarios include: - - Retrieving list with multiple members - - Retrieving list with single member - - Retrieving empty list (no partial members) - - Database query validation - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - query construction and execution. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_get_dataset_partial_member_list_with_members(self, mock_db_session): - """ - Test retrieving partial member list with multiple members. - - Verifies that when a dataset has multiple partial members, all - account IDs are returned correctly. - - This test ensures: - - Query is constructed correctly - - All account IDs are returned - - Database query is executed - """ - # Arrange - dataset_id = "dataset-123" - expected_account_ids = ["user-456", "user-789", "user-012"] - - # Mock the scalars query to return account IDs - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = expected_account_ids - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == expected_account_ids - assert len(result) == 3 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - def test_get_dataset_partial_member_list_with_single_member(self, mock_db_session): - """ - Test retrieving partial member list with single member. - - Verifies that when a dataset has only one partial member, the - single account ID is returned correctly. 
- - This test ensures: - - Query works correctly for single member - - Result is a list with one element - - Database query is executed - """ - # Arrange - dataset_id = "dataset-123" - expected_account_ids = ["user-456"] - - # Mock the scalars query to return single account ID - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = expected_account_ids - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == expected_account_ids - assert len(result) == 1 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - def test_get_dataset_partial_member_list_empty(self, mock_db_session): - """ - Test retrieving partial member list when no members exist. - - Verifies that when a dataset has no partial members, an empty - list is returned. - - This test ensures: - - Empty list is returned correctly - - Query is executed even when no results - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the scalars query to return empty list - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = [] - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == [] - assert len(result) == 0 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - -# ============================================================================ -# Tests for update_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceUpdatePartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.update_partial_member_list method. 
- - This test class covers the update of partial member lists for datasets, - which replaces the existing partial member list with a new one. - - The update_partial_member_list method: - 1. Deletes all existing DatasetPermission records for the dataset - 2. Creates new DatasetPermission records for each user in the list - 3. Adds all new permissions to the session - 4. Commits the transaction - 5. Rolls back on error - - Test scenarios include: - - Adding new partial members - - Updating existing partial members - - Replacing entire member list - - Handling empty member list - - Database transaction handling - - Error handling and rollback - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database operations including queries, adds, commits, and rollbacks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_update_partial_member_list_add_new_members(self, mock_db_session): - """ - Test adding new partial members to a dataset. - - Verifies that when updating with new members, the old members - are deleted and new members are added correctly. 
- - This test ensures: - - Old permissions are deleted - - New permissions are created - - All permissions are added to session - - Transaction is committed - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-456", "user-789"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - mock_query.where.assert_called() - - # Verify new permissions were added - mock_db_session.add_all.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - # Verify no rollback occurred - mock_db_session.rollback.assert_not_called() - - def test_update_partial_member_list_replace_existing(self, mock_db_session): - """ - Test replacing existing partial members with new ones. - - Verifies that when updating with a different member list, the - old members are removed and new members are added. 
- - This test ensures: - - Old permissions are deleted - - New permissions replace old ones - - Transaction is committed successfully - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-999", "user-888"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - - # Verify new permissions were added - mock_db_session.add_all.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_update_partial_member_list_empty_list(self, mock_db_session): - """ - Test updating with empty member list (clearing all members). - - Verifies that when updating with an empty list, all existing - permissions are deleted and no new permissions are added. 
- - This test ensures: - - Old permissions are deleted - - No new permissions are added - - Transaction is committed - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = [] - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - - # Verify add_all was called with empty list - mock_db_session.add_all.assert_called_once_with([]) - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_update_partial_member_list_database_error_rollback(self, mock_db_session): - """ - Test error handling and rollback on database error. - - Verifies that when a database error occurs during the update, - the transaction is rolled back and the error is re-raised. 
- - This test ensures: - - Error is caught and handled - - Transaction is rolled back - - Error is re-raised - - No commit occurs after error - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-456"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Mock commit to raise an error - database_error = Exception("Database connection error") - mock_db_session.commit.side_effect = database_error - - # Act & Assert - with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Verify rollback was called - mock_db_session.rollback.assert_called_once() - - # ============================================================================ # Tests for check_permission # ============================================================================ @@ -776,144 +459,6 @@ class TestDatasetPermissionServiceCheckPermission: mock_get_partial_member_list.assert_called_once_with(dataset.id) -# ============================================================================ -# Tests for clear_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceClearPartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.clear_partial_member_list method. - - This test class covers the clearing of partial member lists, which removes - all DatasetPermission records for a given dataset. - - The clear_partial_member_list method: - 1. Deletes all DatasetPermission records for the dataset - 2. Commits the transaction - 3. 
Rolls back on error - - Test scenarios include: - - Clearing list with existing members - - Clearing empty list (no members) - - Database transaction handling - - Error handling and rollback - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database operations including queries, deletes, commits, and rollbacks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_clear_partial_member_list_success(self, mock_db_session): - """ - Test successful clearing of partial member list. - - Verifies that when clearing a partial member list, all permissions - are deleted and the transaction is committed. - - This test ensures: - - All permissions are deleted - - Transaction is committed - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Assert - # Verify query was executed - mock_db_session.query.assert_called() - - # Verify delete was called - mock_query.where.assert_called() - mock_query.delete.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - # Verify no rollback occurred - mock_db_session.rollback.assert_not_called() - - def test_clear_partial_member_list_empty_list(self, mock_db_session): - """ - Test clearing partial member list when no members exist. - - Verifies that when clearing an already empty list, the operation - completes successfully without errors. 
- - This test ensures: - - Operation works correctly for empty lists - - Transaction is committed - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Assert - # Verify query was executed - mock_db_session.query.assert_called() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_clear_partial_member_list_database_error_rollback(self, mock_db_session): - """ - Test error handling and rollback on database error. - - Verifies that when a database error occurs during clearing, - the transaction is rolled back and the error is re-raised. - - This test ensures: - - Error is caught and handled - - Transaction is rolled back - - Error is re-raised - - No commit occurs after error - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Mock commit to raise an error - database_error = Exception("Database connection error") - mock_db_session.commit.side_effect = database_error - - # Act & Assert - with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Verify rollback was called - mock_db_session.rollback.assert_called_once() - - # ============================================================================ # Tests for DatasetService.check_dataset_permission # ============================================================================ @@ -1047,72 +592,6 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have 
permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_members_with_permission_success(self, mock_db_session): - """ - Test that user with explicit permission can access partial_members dataset. - - Verifies that when a user has an explicit DatasetPermission record - for a partial_members dataset, they can access it successfully. - - This test ensures: - - Explicit permissions are checked correctly - - Users with permissions can access - - Database query is executed - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=user.id - ) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.first.return_value = mock_permission - mock_db_session.query.return_value = mock_query - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - # Assert - # Verify permission query was executed - mock_db_session.query.assert_called() - - def test_check_dataset_permission_partial_members_without_permission_error(self, mock_db_session): - """ - Test error when user without permission tries to access partial_members dataset. - - Verifies that when a user does not have an explicit DatasetPermission - record for a partial_members dataset, a NoPermissionError is raised. 
- - This test ensures: - - Missing permissions are detected - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return None (no permission) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.first.return_value = None # No permission found - mock_db_session.query.return_value = mock_query - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_members_creator_success(self, mock_db_session): """ Test that creator can access partial_members dataset without explicit permission. @@ -1311,72 +790,6 @@ class TestDatasetServiceCheckDatasetOperatorPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - def test_check_dataset_operator_permission_partial_members_with_permission_success(self, mock_db_session): - """ - Test that user with explicit permission can access partial_members dataset. - - Verifies that when a user has an explicit DatasetPermission record - for a partial_members dataset, they can access it successfully. 
- - This test ensures: - - Explicit permissions are checked correctly - - Users with permissions can access - - Database query is executed - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return permission records - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=user.id - ) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.all.return_value = [mock_permission] # User has permission - mock_db_session.query.return_value = mock_query - - # Act (should not raise) - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - # Assert - # Verify permission query was executed - mock_db_session.query.assert_called() - - def test_check_dataset_operator_permission_partial_members_without_permission_error(self, mock_db_session): - """ - Test error when user without permission tries to access partial_members dataset. - - Verifies that when a user does not have an explicit DatasetPermission - record for a partial_members dataset, a NoPermissionError is raised. 
- - This test ensures: - - Missing permissions are detected - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return empty list (no permission) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.all.return_value = [] # No permissions found - mock_db_session.query.return_value = mock_query - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - # ============================================================================ # Additional Documentation and Notes diff --git a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py new file mode 100644 index 0000000000..d5f34d00b9 --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py @@ -0,0 +1,93 @@ +"""Unit tests for PluginManagerService. 
+ +This module covers the pre-uninstall plugin hook behavior: +- Successful API call: no exception raised, correct request sent +- API failure: soft-fail (logs and does not re-raise) +""" + +from unittest.mock import patch + +from httpx import HTTPStatusError + +from configs import dify_config +from services.enterprise.plugin_manager_service import ( + PluginManagerService, + PreUninstallPluginRequest, +) + + +class TestTryPreUninstallPlugin: + def test_try_pre_uninstall_plugin_success(self): + body = PreUninstallPluginRequest( + tenant_id="tenant-123", + plugin_unique_identifier="com.example.my_plugin", + ) + + with patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request: + mock_send_request.return_value = {} + + PluginManagerService.try_pre_uninstall_plugin(body) + + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"}, + raise_for_status=True, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + + def test_try_pre_uninstall_plugin_http_error_soft_fails(self): + body = PreUninstallPluginRequest( + tenant_id="tenant-456", + plugin_unique_identifier="com.example.other_plugin", + ) + + with ( + patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request, + patch("services.enterprise.plugin_manager_service.logger") as mock_logger, + ): + mock_send_request.side_effect = HTTPStatusError( + "502 Bad Gateway", + request=None, + response=None, + ) + + PluginManagerService.try_pre_uninstall_plugin(body) + + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"}, + raise_for_status=True, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + mock_logger.exception.assert_called_once() + + def 
test_try_pre_uninstall_plugin_generic_exception_soft_fails(self): + body = PreUninstallPluginRequest( + tenant_id="tenant-789", + plugin_unique_identifier="com.example.failing_plugin", + ) + + with ( + patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request, + patch("services.enterprise.plugin_manager_service.logger") as mock_logger, + ): + mock_send_request.side_effect = ConnectionError("network unreachable") + + PluginManagerService.try_pre_uninstall_plugin(body) + + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"}, + raise_for_status=True, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + mock_logger.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py new file mode 100644 index 0000000000..a34defeba9 --- /dev/null +++ b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py @@ -0,0 +1,309 @@ +import datetime +import os +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanService: + @pytest.fixture(autouse=True) + def mock_db_engine(self): + with patch("services.retention.conversation.messages_clean_service.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db.engine + + @pytest.fixture + def mock_db_session(self, mock_db_engine): + with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + yield mock_session + + 
@pytest.fixture + def mock_policy(self): + policy = MagicMock(spec=BillingDisabledPolicy) + return policy + + def test_run_calls_clean_messages(self, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + with patch.object(service, "_clean_messages_by_time_range") as mock_clean: + mock_clean.return_value = {"total_deleted": 5} + result = service.run() + assert result == {"total_deleted": 5} + mock_clean.assert_called_once() + + def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): + # Arrange + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + end_before=end_before, + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock( + rowcount=1 + ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete messages + MagicMock(all=lambda: []), # next batch empty + ] + + # Reset side_effect to be more robust + # The service calls session.execute for: + # 1. Fetch messages + # 2. Fetch apps + # 3. Batch delete relations (8 calls if IDs exist) + # 4. 
Delete messages + + mock_returns = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps + ] + # 8 deletes for relations + mock_returns.extend([MagicMock() for _ in range(8)]) + # 1 delete for messages + mock_returns.append(MagicMock(rowcount=1)) + # Final fetch messages (empty) + mock_returns.append(MagicMock(all=lambda: [])) + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] + + # Act + with patch("services.retention.conversation.messages_clean_service.time.sleep"): + stats = service.run() + + # Assert + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 1 + assert stats["batches"] == 2 + + def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + start_from=start_from, + end_before=end_before, + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: []), # No messages + ] + + stats = service.run() + assert stats["total_messages"] == 0 + + def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): + # Test pagination with cursor + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + end_before=end_before, + batch_size=1, + ) + + msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) + msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) + + mock_returns = [] + # Batch 1 + mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) + mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # 
messages + + # Batch 2 + mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) + mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + + # Batch 3 + mock_returns.append(MagicMock(all=lambda: [])) + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified + + with patch("services.retention.conversation.messages_clean_service.time.sleep"): + stats = service.run() + + assert stats["batches"] == 3 + assert stats["total_messages"] == 2 + + def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + dry_run=True, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock(all=lambda: []), # next batch empty + ] + mock_policy.filter_message_ids.return_value = ["msg1"] + + with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: + mock_sample.return_value = ["msg1"] + stats = service.run() + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 0 # Dry run + mock_sample.assert_called() + + def test_clean_messages_by_time_range_no_apps_found(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: []), # apps NOT found + MagicMock(all=lambda: []), # next batch empty + ] + + stats = service.run() + assert stats["total_messages"] == 1 + assert 
stats["total_deleted"] == 0 + + def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: []), # next batch empty + ] + + # We need to successfully execute line 228 and 229, then return empty at 251. + # line 228: raw_messages = list(session.execute(msg_stmt).all()) + # line 251: app_ids = list({msg.app_id for msg in messages}) + + calls = [] + + def list_side_effect(arg): + calls.append(arg) + if len(calls) == 2: # This is the second call to list() in the loop + return [] + return list(arg) + + with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): + stats = service.run() + assert stats["batches"] == 2 + assert stats["total_messages"] == 1 + + def test_from_time_range_validation(self, mock_policy): + now = datetime.datetime.now() + # Test start_from >= end_before + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range(mock_policy, now, now) + + # Test batch_size <= 0 + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) + + def test_from_time_range_success(self, mock_policy): + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 2, 1) + # Mock logger to avoid actual logging if needed, though it's fine + service = MessagesCleanService.from_time_range(mock_policy, start, end) + assert service._start_from == start + assert service._end_before == end + + def test_from_days_validation(self, mock_policy): + # Test days < 0 + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + 
MessagesCleanService.from_days(mock_policy, days=-1) + + # Test batch_size <= 0 + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) + + def test_from_days_success(self, mock_policy): + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: + fixed_now = datetime.datetime(2024, 6, 1) + mock_now.return_value = fixed_now + + service = MessagesCleanService.from_days(mock_policy, days=10) + assert service._start_from is None + assert service._end_before == fixed_now - datetime.timedelta(days=10) + + def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock(all=lambda: []), # next batch empty + ] + mock_policy.filter_message_ids.return_value = [] # Policy says NO + + stats = service.run() + assert stats["total_messages"] == 1 + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_batch_delete_message_relations_empty(self, mock_db_session): + MessagesCleanService._batch_delete_message_relations(mock_db_session, []) + mock_db_session.execute.assert_not_called() + + def test_batch_delete_message_relations_with_ids(self, mock_db_session): + MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) + assert mock_db_session.execute.call_count == 8 # 8 tables to clean up + + @patch.dict(os.environ, {"SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL": "500"}) + def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + 
end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_returns = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + ] + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + mock_returns.append(MagicMock(all=lambda: [])) # next batch empty + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] + + with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: + with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: + mock_uniform.return_value = 300.0 + service.run() + mock_uniform.assert_called_with(0, 500) + mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..0013cde79e --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,499 @@ +""" +Unit tests for WorkflowRunCleanup service. 
+""" + +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +def make_run(tenant_id: str = "t1", run_id: str = "r1", created_at: datetime.datetime | None = None): + run = MagicMock() + run.tenant_id = tenant_id + run.id = run_id + run.created_at = created_at or datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC) + return run + + +@pytest.fixture +def mock_repo(): + return MagicMock() + + +@pytest.fixture +def cleanup(mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + yield WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + +# --------------------------------------------------------------------------- +# Constructor validation +# --------------------------------------------------------------------------- + + +class TestWorkflowRunCleanupInit: + def test_only_start_from_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="both set or both omitted"): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_only_end_before_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="both set or both omitted"): + WorkflowRunCleanup( + days=30, + batch_size=10, + end_before=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def 
test_end_before_not_greater_than_start_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="end_before must be greater than start_from"): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 6, 1), + end_before=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_equal_start_end_raises(self, mock_repo): + dt = datetime.datetime(2024, 1, 1) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError): + WorkflowRunCleanup(days=30, batch_size=10, start_from=dt, end_before=dt, workflow_run_repo=mock_repo) + + def test_zero_batch_size_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="batch_size must be greater than 0"): + WorkflowRunCleanup(days=30, batch_size=0, workflow_run_repo=mock_repo) + + def test_negative_batch_size_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError): + WorkflowRunCleanup(days=30, batch_size=-1, workflow_run_repo=mock_repo) + + def test_valid_window_init(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 7 + cfg.BILLING_ENABLED = False + start = datetime.datetime(2024, 1, 1) + end = 
datetime.datetime(2024, 6, 1) + c = WorkflowRunCleanup(days=30, batch_size=5, start_from=start, end_before=end, workflow_run_repo=mock_repo) + assert c.window_start == start + assert c.window_end == end + + +# --------------------------------------------------------------------------- +# _empty_related_counts / _format_related_counts +# --------------------------------------------------------------------------- + + +class TestStaticHelpers: + def test_empty_related_counts(self): + counts = WorkflowRunCleanup._empty_related_counts() + assert counts == { + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def test_format_related_counts(self): + counts = { + "node_executions": 1, + "offloads": 2, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + } + result = WorkflowRunCleanup._format_related_counts(counts) + assert "node_executions 1" in result + assert "offloads 2" in result + assert "trigger_logs 4" in result + + +# --------------------------------------------------------------------------- +# _expiration_datetime +# --------------------------------------------------------------------------- + + +class TestExpirationDatetime: + def test_negative_returns_none(self, cleanup): + assert cleanup._expiration_datetime("t1", -1) is None + + def test_valid_timestamp(self, cleanup): + ts = int(datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC).timestamp()) + result = cleanup._expiration_datetime("t1", ts) + assert result is not None + assert result.year == 2025 + + def test_overflow_returns_none(self, cleanup): + result = cleanup._expiration_datetime("t1", 2**62) + assert result is None + + +# --------------------------------------------------------------------------- +# _is_within_grace_period +# --------------------------------------------------------------------------- + + +class TestIsWithinGracePeriod: + def test_zero_grace_period_returns_false(self, cleanup): + 
cleanup.free_plan_grace_period_days = 0 + assert cleanup._is_within_grace_period("t1", {"expiration_date": 9999999999}) is False + + def test_within_grace_period(self, cleanup): + cleanup.free_plan_grace_period_days = 30 + # expired just 1 day ago + expired = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=1) + ts = int(expired.timestamp()) + assert cleanup._is_within_grace_period("t1", {"expiration_date": ts}) is True + + def test_outside_grace_period(self, cleanup): + cleanup.free_plan_grace_period_days = 5 + # expired 100 days ago + expired = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=100) + ts = int(expired.timestamp()) + assert cleanup._is_within_grace_period("t1", {"expiration_date": ts}) is False + + def test_missing_expiration_date_returns_false(self, cleanup): + cleanup.free_plan_grace_period_days = 30 + assert cleanup._is_within_grace_period("t1", {"expiration_date": -1}) is False + + +# --------------------------------------------------------------------------- +# _get_cleanup_whitelist +# --------------------------------------------------------------------------- + + +class TestGetCleanupWhitelist: + def test_billing_disabled_returns_empty(self, cleanup): + cleanup._cleanup_whitelist = None + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + result = cleanup._get_cleanup_whitelist() + assert result == set() + + def test_billing_enabled_fetches_whitelist(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_expired_subscription_cleanup_whitelist.return_value = ["t1", "t2"] + 
result = c._get_cleanup_whitelist() + assert result == {"t1", "t2"} + + def test_cached_whitelist_returned(self, cleanup): + cleanup._cleanup_whitelist = {"cached"} + result = cleanup._get_cleanup_whitelist() + assert result == {"cached"} + + def test_billing_service_error_returns_empty(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_expired_subscription_cleanup_whitelist.side_effect = Exception("error") + result = c._get_cleanup_whitelist() + assert result == set() + + +# --------------------------------------------------------------------------- +# _filter_free_tenants +# --------------------------------------------------------------------------- + + +class TestFilterFreeTenants: + def test_billing_disabled_all_tenants_free(self, cleanup): + result = cleanup._filter_free_tenants(["t1", "t2"]) + assert result == {"t1", "t2"} + + def test_empty_tenants_returns_empty(self, cleanup): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = True + result = cleanup._filter_free_tenants([]) + assert result == set() + + def test_whitelisted_tenant_excluded(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = {"t1"} + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + from enums.cloud_plan import CloudPlan + + 
bs.get_plan_bulk_with_cache.return_value = { + "t1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "t2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + result = c._filter_free_tenants(["t1", "t2"]) + assert "t1" not in result + assert "t2" in result + + def test_paid_tenant_excluded(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.return_value = { + "t1": {"plan": "professional", "expiration_date": -1}, + } + result = c._filter_free_tenants(["t1"]) + assert result == set() + + def test_missing_billing_info_treats_as_non_free(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.return_value = {} + result = c._filter_free_tenants(["t1"]) + assert result == set() + + def test_billing_bulk_error_treats_as_non_free(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + 
"services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.side_effect = Exception("fail") + result = c._filter_free_tenants(["t1"]) + assert result == set() + + +# --------------------------------------------------------------------------- +# run() — delete mode +# --------------------------------------------------------------------------- + + +class TestRunDeleteMode: + def _make_cleanup(self, mock_repo, billing_enabled=False): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = billing_enabled + return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + def test_no_rows_stops_immediately(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + + def test_all_paid_skips_delete(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + c = self._make_cleanup(mock_repo) + # billing disabled -> all free; but let's override _filter_free_tenants to return empty + c._filter_free_tenants = MagicMock(return_value=set()) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + + def test_runs_deleted_successfully(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.delete_runs_with_related.return_value = { + "runs": 1, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + 
"pause_reasons": 0, + } + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.time.sleep"): + c.run() + mock_repo.delete_runs_with_related.assert_called_once() + + def test_delete_exception_reraises(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.delete_runs_with_related.side_effect = RuntimeError("db error") + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + with pytest.raises(RuntimeError): + c.run() + + def test_summary_with_window_start(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 6, 1), + workflow_run_repo=mock_repo, + ) + c.run() + + +# --------------------------------------------------------------------------- +# run() — dry run mode +# --------------------------------------------------------------------------- + + +class TestRunDryRunMode: + def _make_dry_cleanup(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo, dry_run=True) + + def test_dry_run_no_delete_called(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = 
[[run], []] + mock_repo.count_runs_with_related.return_value = { + "node_executions": 2, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 1, + "pauses": 0, + "pause_reasons": 0, + } + c = self._make_dry_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + mock_repo.count_runs_with_related.assert_called_once() + + def test_dry_run_summary_with_window_start(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 6, 1), + workflow_run_repo=mock_repo, + dry_run=True, + ) + c.run() + + def test_dry_run_all_paid_skips_count(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + c = self._make_dry_cleanup(mock_repo) + c._filter_free_tenants = MagicMock(return_value=set()) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.count_runs_with_related.assert_not_called() + + +# --------------------------------------------------------------------------- +# _delete_trigger_logs / _count_trigger_logs +# --------------------------------------------------------------------------- + + +class TestTriggerLogMethods: + def test_delete_trigger_logs(self, cleanup): + session = MagicMock() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository" + ) as RepoClass: + instance = RepoClass.return_value + 
instance.delete_by_run_ids.return_value = 5 + result = cleanup._delete_trigger_logs(session, ["r1", "r2"]) + assert result == 5 + + def test_count_trigger_logs(self, cleanup): + session = MagicMock() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository" + ) as RepoClass: + instance = RepoClass.return_value + instance.count_by_run_ids.return_value = 3 + result = cleanup._count_trigger_logs(session, ["r1"]) + assert result == 3 + + +# --------------------------------------------------------------------------- +# _count_node_executions / _delete_node_executions +# --------------------------------------------------------------------------- + + +class TestNodeExecutionMethods: + def test_count_node_executions(self, cleanup): + session = MagicMock() + session.get_bind.return_value = MagicMock() + runs = [make_run("t1", "r1")] + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory" + ) as factory: + repo = factory.create_api_workflow_node_execution_repository.return_value + repo.count_by_runs.return_value = (10, 2) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"): + result = cleanup._count_node_executions(session, runs) + assert result == (10, 2) + + def test_delete_node_executions(self, cleanup): + session = MagicMock() + session.get_bind.return_value = MagicMock() + runs = [make_run("t1", "r1")] + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory" + ) as factory: + repo = factory.create_api_workflow_node_execution_repository.return_value + repo.delete_by_runs.return_value = (5, 1) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"): + result = cleanup._delete_node_executions(session, runs) + assert result == (5, 1) diff --git 
a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..9fe153c153 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py @@ -0,0 +1,216 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy.orm import Session + +from models.workflow import WorkflowRun +from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult + + +class TestArchivedWorkflowRunDeletion: + @pytest.fixture + def mock_db(self): + with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db + + @pytest.fixture + def mock_sessionmaker(self): + with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: + mock_session = MagicMock(spec=Session) + mock_sm.return_value.return_value.__enter__.return_value = mock_session + yield mock_sm, mock_session + + @pytest.fixture + def mock_workflow_run_repo(self): + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" + ) as mock_repo_cls: + mock_repo = MagicMock() + yield mock_repo + + def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + tenant_id = "tenant-456" + + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = run_id + mock_run.tenant_id = tenant_id + mock_session.get.return_value = mock_run + + deletion = ArchivedWorkflowRunDeletion() + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_run_ids.return_value = [run_id] + + with patch.object(deletion, "_delete_run") as 
mock_delete_run: + expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) + mock_delete_run.return_value = expected_result + + result = deletion.delete_by_run_id(run_id) + + assert result == expected_result + mock_session.get.assert_called_once_with(WorkflowRun, run_id) + mock_repo.get_archived_run_ids.assert_called_once() + mock_delete_run.assert_called_once_with(mock_run) + + def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + mock_session.get.return_value = None + + deletion = ArchivedWorkflowRunDeletion() + with patch.object(deletion, "_get_workflow_run_repo"): + result = deletion.delete_by_run_id(run_id) + + assert result.success is False + assert "not found" in result.error + assert result.run_id == run_id + + def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = run_id + mock_session.get.return_value = mock_run + + deletion = ArchivedWorkflowRunDeletion() + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_run_ids.return_value = [] + + result = deletion.delete_by_run_id(run_id) + + assert result.success is False + assert "is not archived" in result.error + + def test_delete_batch(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + deletion = ArchivedWorkflowRunDeletion() + + mock_run1 = MagicMock(spec=WorkflowRun) + mock_run1.id = "run-1" + mock_run2 = MagicMock(spec=WorkflowRun) + mock_run2.id = "run-2" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] + + with patch.object(deletion, 
"_delete_run") as mock_delete_run: + mock_delete_run.side_effect = [ + DeleteResult(run_id="run-1", tenant_id="t1", success=True), + DeleteResult(run_id="run-2", tenant_id="t1", success=True), + ] + + results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) + + assert len(results) == 2 + assert results[0].run_id == "run-1" + assert results[1].run_id == "run-2" + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=True) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + mock_run.tenant_id = "tenant-456" + + result = deletion._delete_run(mock_run) + + assert result.success is True + assert result.run_id == "run-123" + + def test_delete_run_success(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=False) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + mock_run.tenant_id = "tenant-456" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} + + result = deletion._delete_run(mock_run) + + assert result.success is True + assert result.deleted_counts == {"workflow_runs": 1} + + def test_delete_run_exception(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=False) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deletion._delete_run(mock_run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_trigger_logs(self): + mock_session = MagicMock(spec=Session) + run_ids = ["run-1", "run-2"] + + with patch( + 
"services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" + ) as mock_repo_cls: + mock_repo = MagicMock() + mock_repo_cls.return_value = mock_repo + mock_repo.delete_by_run_ids.return_value = 5 + + count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) + + assert count == 5 + mock_repo_cls.assert_called_once_with(mock_session) + mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) + + def test_delete_node_executions(self): + mock_session = MagicMock(spec=Session) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-1" + runs = [mock_run] + + with patch( + "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" + ) as mock_create_repo: + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + mock_repo.delete_by_runs.return_value = (1, 2) + + with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: + result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) + + assert result == (1, 2) + mock_create_repo.assert_called_once() + mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) + + def test_get_workflow_run_repo(self, mock_db): + deletion = ArchivedWorkflowRunDeletion() + + with patch( + "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" + ) as mock_create_repo: + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + + # First call + repo1 = deletion._get_workflow_run_repo() + assert repo1 == mock_repo + assert deletion.workflow_run_repo == mock_repo + + # Second call (should return cached) + repo2 = deletion._get_workflow_run_repo() + assert repo2 == mock_repo + mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py new 
file mode 100644 index 0000000000..6097bcbd61 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -0,0 +1,1020 @@ +""" +Comprehensive unit tests for WorkflowRunRestore service. + +This file provides complete test coverage for all WorkflowRunRestore methods. +Tests are organized by functionality and include edge cases, error handling, +and both positive and negative test scenarios. +""" + +import io +import json +import zipfile +from datetime import datetime +from unittest.mock import Mock, create_autospec, patch + +import pytest + +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from services.retention.workflow_run.restore_archived_workflow_run import ( + SCHEMA_MAPPERS, + TABLE_MODELS, + RestoreResult, + WorkflowRunRestore, +) + + +class WorkflowRunRestoreTestDataFactory: + """ + Factory for creating test data and mock objects. + + Provides reusable methods to create consistent mock objects for testing + workflow run restore operations. + """ + + @staticmethod + def create_workflow_run_mock( + run_id: str = "run-123", + tenant_id: str = "tenant-123", + app_id: str = "app-123", + created_at: datetime | None = None, + **kwargs, + ) -> Mock: + """ + Create a mock WorkflowRun object. 
+ + Args: + run_id: Unique identifier for the workflow run + tenant_id: Tenant/workspace identifier + app_id: Application identifier + created_at: Creation timestamp + **kwargs: Additional attributes to set on the mock + + Returns: + Mock WorkflowRun object with specified attributes + """ + run = create_autospec(WorkflowRun, instance=True) + run.id = run_id + run.tenant_id = tenant_id + run.app_id = app_id + run.created_at = created_at or datetime(2024, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(run, key, value) + return run + + @staticmethod + def create_workflow_archive_log_mock( + run_id: str = "run-123", + tenant_id: str = "tenant-123", + app_id: str = "app-123", + created_at: datetime | None = None, + **kwargs, + ) -> Mock: + """ + Create a mock WorkflowArchiveLog object. + + Args: + run_id: Unique identifier for the workflow run + tenant_id: Tenant/workspace identifier + app_id: Application identifier + created_at: Creation timestamp + **kwargs: Additional attributes to set on the mock + + Returns: + Mock WorkflowArchiveLog object with specified attributes + """ + archive_log = create_autospec(WorkflowArchiveLog, instance=True) + archive_log.workflow_run_id = run_id + archive_log.tenant_id = tenant_id + archive_log.app_id = app_id + archive_log.run_created_at = created_at or datetime(2024, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(archive_log, key, value) + return archive_log + + @staticmethod + def create_archive_zip_mock( + manifest: dict | None = None, + tables_data: dict[str, list[dict]] | None = None, + ) -> bytes: + """ + Create a mock archive zip file in memory. 
+ + Args: + manifest: Archive manifest data + tables_data: Dictionary mapping table names to list of records + + Returns: + Bytes representing the zip file + """ + if manifest is None: + manifest = { + "schema_version": "1.0", + "tables": { + "workflow_runs": {"row_count": 1}, + "workflow_app_logs": {"row_count": 2}, + }, + } + + if tables_data is None: + tables_data = { + "workflow_runs": [{"id": "run-123", "tenant_id": "tenant-123"}], + "workflow_app_logs": [ + {"id": "log-1", "workflow_run_id": "run-123"}, + {"id": "log-2", "workflow_run_id": "run-123"}, + ], + } + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + zip_file.writestr("manifest.json", json.dumps(manifest)) + for table_name, records in tables_data.items(): + jsonl_data = "\n".join(json.dumps(record) for record in records) + zip_file.writestr(f"{table_name}.jsonl", jsonl_data) + + zip_buffer.seek(0) + return zip_buffer.getvalue() + + +# --------------------------------------------------------------------------- +# Test WorkflowRunRestore Initialization +# --------------------------------------------------------------------------- + + +class TestWorkflowRunRestoreInit: + """Tests for WorkflowRunRestore.__init__ method.""" + + def test_default_initialization(self): + """Service should initialize with default values.""" + restore = WorkflowRunRestore() + assert restore.dry_run is False + assert restore.workers == 1 + assert restore.workflow_run_repo is None + + def test_dry_run_initialization(self): + """Service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + assert restore.dry_run is True + assert restore.workers == 1 + + def test_custom_workers_initialization(self): + """Service should accept custom workers count.""" + restore = WorkflowRunRestore(workers=5) + assert restore.workers == 5 + + def test_invalid_workers_raises_error(self): + """Service should raise ValueError for workers less than 1.""" + with 
pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunRestore(workers=0) + + def test_negative_workers_raises_error(self): + """Service should raise ValueError for negative workers.""" + with pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunRestore(workers=-1) + + +# --------------------------------------------------------------------------- +# Test _get_workflow_run_repo Method +# --------------------------------------------------------------------------- + + +class TestGetWorkflowRunRepo: + """Tests for WorkflowRunRestore._get_workflow_run_repo method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.DifyAPIRepositoryFactory") + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + @patch("services.retention.workflow_run.restore_archived_workflow_run.db") + def test_first_call_creates_repo(self, mock_db, mock_sessionmaker, mock_factory): + """First call should create and cache repository.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + mock_repo = Mock() + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + result = restore._get_workflow_run_repo() + + assert result is mock_repo + assert restore.workflow_run_repo is mock_repo + mock_sessionmaker.assert_called_once_with(bind=mock_db.engine, expire_on_commit=False) + mock_factory.create_api_workflow_run_repository.assert_called_once_with(mock_session) + + def test_cached_repo_returned(self): + """Subsequent calls should return cached repository.""" + restore = WorkflowRunRestore() + mock_repo = Mock() + restore.workflow_run_repo = mock_repo + + result = restore._get_workflow_run_repo() + + assert result is mock_repo + + +# --------------------------------------------------------------------------- +# Test _load_manifest_from_zip Method +# --------------------------------------------------------------------------- + + 
+class TestLoadManifestFromZip: + """Tests for WorkflowRunRestore._load_manifest_from_zip method.""" + + def test_load_valid_manifest(self): + """Should load manifest from valid zip.""" + manifest_data = {"schema_version": "1.0", "tables": {}} + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("manifest.json", json.dumps(manifest_data)) + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + result = WorkflowRunRestore._load_manifest_from_zip(archive) + + assert result == manifest_data + + def test_missing_manifest_raises_error(self): + """Should raise ValueError when manifest.json is missing.""" + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("other.txt", "data") + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + with pytest.raises(ValueError, match="manifest.json missing from archive bundle"): + WorkflowRunRestore._load_manifest_from_zip(archive) + + def test_invalid_json_raises_error(self): + """Should raise ValueError when manifest contains invalid JSON.""" + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("manifest.json", "invalid json") + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + with pytest.raises(json.JSONDecodeError): + WorkflowRunRestore._load_manifest_from_zip(archive) + + +# --------------------------------------------------------------------------- +# Test _get_schema_version Method +# --------------------------------------------------------------------------- + + +class TestGetSchemaVersion: + """Tests for WorkflowRunRestore._get_schema_version method.""" + + def test_valid_schema_version(self): + """Should return valid schema version from manifest.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": "1.0"} + result = restore._get_schema_version(manifest) + assert result == "1.0" + + def 
test_missing_schema_version_defaults_to_1_0(self): + """Should default to 1.0 when schema_version is missing.""" + restore = WorkflowRunRestore() + manifest = {"tables": {}} + + with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: + result = restore._get_schema_version(manifest) + + assert result == "1.0" + mock_logger.warning.assert_called_once_with("Manifest missing schema_version; defaulting to 1.0") + + def test_unsupported_schema_version_raises_error(self): + """Should raise ValueError for unsupported schema version.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": "2.0"} + + with pytest.raises(ValueError, match="Unsupported schema_version 2.0"): + restore._get_schema_version(manifest) + + def test_numeric_schema_version_converted_to_string(self): + """Should convert numeric schema version to string.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": 1} + + # This should raise ValueError because "1" is not in SCHEMA_MAPPERS (only "1.0" is) + with pytest.raises(ValueError, match="Unsupported schema_version 1"): + restore._get_schema_version(manifest) + + +# --------------------------------------------------------------------------- +# Test _apply_schema_mapping Method +# --------------------------------------------------------------------------- + + +class TestApplySchemaMapping: + """Tests for WorkflowRunRestore._apply_schema_mapping method.""" + + def test_no_mapping_returns_original(self): + """Should return original record when no mapping exists.""" + restore = WorkflowRunRestore() + record = {"id": "test", "name": "test"} + result = restore._apply_schema_mapping("workflow_runs", "1.0", record) + assert result == record + + def test_mapping_applied(self): + """Should apply mapping when it exists.""" + restore = WorkflowRunRestore() + + def test_mapper(record): + return {**record, "mapped": True} + + # Add test mapper to SCHEMA_MAPPERS + original_mappers = SCHEMA_MAPPERS.copy() 
+ SCHEMA_MAPPERS["1.0"]["test_table"] = test_mapper + + try: + record = {"id": "test"} + result = restore._apply_schema_mapping("test_table", "1.0", record) + assert result == {"id": "test", "mapped": True} + finally: + # Restore original mappers + SCHEMA_MAPPERS.clear() + SCHEMA_MAPPERS.update(original_mappers) + + +# --------------------------------------------------------------------------- +# Test _convert_datetime_fields Method +# --------------------------------------------------------------------------- + + +class TestConvertDatetimeFields: + """Tests for WorkflowRunRestore._convert_datetime_fields method.""" + + def test_iso_datetime_conversion(self): + """Should convert ISO datetime strings to datetime objects.""" + restore = WorkflowRunRestore() + + record = {"created_at": "2024-01-01T12:00:00", "name": "test"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["name"] == "test" + + def test_invalid_datetime_ignored(self): + """Should ignore invalid datetime strings.""" + restore = WorkflowRunRestore() + + record = {"created_at": "invalid-date", "name": "test"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert result["created_at"] == "invalid-date" + assert result["name"] == "test" + + def test_non_datetime_columns_unchanged(self): + """Should leave non-datetime columns unchanged.""" + restore = WorkflowRunRestore() + + record = {"id": "test", "tenant_id": "tenant-123"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert result["id"] == "test" + assert result["tenant_id"] == "tenant-123" + + +# --------------------------------------------------------------------------- +# Test _get_model_column_info Method +# --------------------------------------------------------------------------- + + +class TestGetModelColumnInfo: + """Tests for WorkflowRunRestore._get_model_column_info 
method.""" + + def test_column_info_extraction(self): + """Should extract column information correctly.""" + restore = WorkflowRunRestore() + + column_names, required_columns, non_nullable_with_default = restore._get_model_column_info(WorkflowRun) + + # Check that we get some expected columns + assert "id" in column_names + assert "tenant_id" in column_names + assert "app_id" in column_names + assert "created_at" in column_names + assert "created_by" in column_names + assert "status" in column_names + + # WorkflowRun model has no required columns (all have defaults or are auto-generated) + assert required_columns == set() + + # Check columns with defaults or server defaults + assert "id" in non_nullable_with_default + assert "created_at" in non_nullable_with_default + assert "elapsed_time" in non_nullable_with_default + assert "total_tokens" in non_nullable_with_default + + +# --------------------------------------------------------------------------- +# Test _restore_table_records Method +# --------------------------------------------------------------------------- + + +class TestRestoreTableRecords: + """Tests for WorkflowRunRestore._restore_table_records method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.TABLE_MODELS") + def test_unknown_table_returns_zero(self, mock_table_models): + """Should return 0 for unknown table.""" + restore = WorkflowRunRestore() + mock_table_models.get.return_value = None + + mock_session = Mock() + records = [{"id": "test"}] + + with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: + result = restore._restore_table_records(mock_session, "unknown_table", records, schema_version="1.0") + + assert result == 0 + mock_logger.warning.assert_called_once_with("Unknown table: %s", "unknown_table") + + def test_empty_records_returns_zero(self): + """Should return 0 for empty records list.""" + restore = WorkflowRunRestore() + mock_session = Mock() + + result = 
restore._restore_table_records(mock_session, "workflow_runs", [], schema_version="1.0") + assert result == 0 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") + @patch("services.retention.workflow_run.restore_archived_workflow_run.cast") + def test_successful_restore(self, mock_cast, mock_pg_insert): + """Should successfully restore records.""" + restore = WorkflowRunRestore() + + # Mock session and execution + mock_session = Mock() + mock_result = Mock() + mock_result.rowcount = 2 + mock_session.execute.return_value = mock_result + mock_cast.return_value = mock_result + + # Mock insert statement + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_pg_insert.return_value = mock_stmt + + records = [{"id": "test1", "tenant_id": "tenant-123"}, {"id": "test2", "tenant_id": "tenant-123"}] + + result = restore._restore_table_records(mock_session, "workflow_runs", records, schema_version="1.0") + + assert result == 2 + mock_session.execute.assert_called_once() + + def test_missing_required_columns_raises_error(self): + """Should raise ValueError for missing required columns.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + # Since WorkflowRun has no required columns, we need to test with a different model + # Let's test with a mock model that has required columns + mock_model = Mock() + + # Mock a required column + required_column = Mock() + required_column.key = "required_field" + required_column.nullable = False + required_column.default = None + required_column.server_default = None + required_column.autoincrement = False + required_column.type = Mock() + + # Mock the __table__ attribute properly + mock_table = Mock() + mock_table.columns = [required_column] + mock_model.__table__ = mock_table + + records = [{"name": "test"}] # Missing required 'required_field' + + with patch.dict(TABLE_MODELS, {"test_table": mock_model}): + with pytest.raises(ValueError, match="Missing required 
columns for test_table"): + restore._restore_table_records(mock_session, "test_table", records, schema_version="1.0") + + +# --------------------------------------------------------------------------- +# Test _restore_from_run Method +# --------------------------------------------------------------------------- + + +class TestRestoreFromRun: + """Tests for WorkflowRunRestore._restore_from_run method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_archive_storage_not_configured(self, mock_get_storage): + """Should handle ArchiveStorageNotConfiguredError.""" + restore = WorkflowRunRestore() + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("Storage not configured") + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: Mock()) + + assert result.success is False + assert "Storage not configured" in result.error + assert result.elapsed_time > 0 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_archive_bundle_not_found(self, mock_get_storage): + """Should handle FileNotFoundError when archive bundle is missing.""" + restore = WorkflowRunRestore() + mock_storage = Mock() + mock_storage.get_object.side_effect = FileNotFoundError("Bundle not found") + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: Mock()) + + assert result.success is False + assert "Archive bundle not found" in result.error + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_dry_run_mode(self, mock_get_storage): 
+ """Should handle dry run mode correctly.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Create a proper mock session with context manager support + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + result = restore._restore_from_run(run, session_maker=lambda: mock_session) + + assert result.success is True + assert result.restored_counts["workflow_runs"] == 1 + assert result.restored_counts["workflow_app_logs"] == 2 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") + @patch("services.retention.workflow_run.restore_archived_workflow_run.cast") + def test_successful_restore(self, mock_cast, mock_pg_insert, mock_get_storage): + """Should successfully restore from archive.""" + restore = WorkflowRunRestore() + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + # Mock session with context manager support + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + def session_maker(): + return mock_session + + # Mock database execution to return integer counts + mock_result_workflow_runs = Mock() + mock_result_workflow_runs.rowcount = 1 + mock_result_app_logs = Mock() + mock_result_app_logs.rowcount = 2 + + # Configure session.execute to return different results based on the table + def 
mock_execute(stmt): + if "workflow_runs" in str(stmt): + return mock_result_workflow_runs + else: + return mock_result_app_logs + + mock_session.execute.side_effect = mock_execute + mock_cast.return_value = mock_result_workflow_runs + + # Mock insert statement + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_pg_insert.return_value = mock_stmt + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Mock repository methods + with patch.object(restore, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = Mock() + mock_get_repo.return_value = mock_repo + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=session_maker) + + assert result.success is True + assert result.restored_counts["workflow_runs"] == 1 + assert result.restored_counts["workflow_app_logs"] >= 1 # Just check it's restored + mock_session.commit.assert_called_once() + mock_repo.delete_archive_log_by_run_id.assert_called_once_with(mock_session, run.id) + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_invalid_archive_bundle(self, mock_get_storage): + """Should handle invalid archive bundle.""" + restore = WorkflowRunRestore() + + # Mock storage with invalid zip data + mock_storage = Mock() + mock_storage.get_object.return_value = b"invalid zip data" + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Create proper mock session + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: mock_session) + + assert result.success is False + # The error message comes from 
zipfile.BadZipFile which says "File is not a zip file" + assert "File is not a zip file" in result.error + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_workflow_archive_log_input(self, mock_get_storage): + """Should handle WorkflowArchiveLog input correctly.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + + # Create proper mock session + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + result = restore._restore_from_run(archive_log, session_maker=lambda: mock_session) + + assert result.success is True + assert result.run_id == archive_log.workflow_run_id + assert result.tenant_id == archive_log.tenant_id + + +# --------------------------------------------------------------------------- +# Test restore_batch Method +# --------------------------------------------------------------------------- + + +class TestRestoreBatch: + """Tests for WorkflowRunRestore.restore_batch method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_empty_tenant_ids_returns_empty(self, mock_sessionmaker): + """Should return empty list when tenant_ids is empty list.""" + restore = WorkflowRunRestore() + + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + result = restore.restore_batch( + tenant_ids=[], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert result == [] + + 
@patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_successful_batch_restore(self, mock_executor): + """Should successfully restore batch of workflow runs.""" + restore = WorkflowRunRestore(workers=2) + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + # Mock repository and archive logs + mock_repo = Mock() + archive_log1 = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock("run-1") + archive_log2 = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock("run-2") + mock_repo.get_archived_logs_by_time_range.return_value = [archive_log1, archive_log2] + + # Mock restore results + result1 = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={}) + result2 = RestoreResult(run_id="run-2", tenant_id="tenant-1", success=True, restored_counts={}) + + # Mock ThreadPoolExecutor with context manager support + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[result1, result2]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", side_effect=[result1, result2]): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with 
patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + results = restore.restore_batch( + tenant_ids=["tenant-1"], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert len(results) == 2 + assert results[0].run_id == "run-1" + assert results[1].run_id == "run-2" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_dry_run_batch_restore(self, mock_executor): + """Should handle dry run mode for batch restore.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_logs_by_time_range.return_value = [archive_log] + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={"workflow_runs": 1}) + + # Mock ThreadPoolExecutor with context manager support + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[result]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + 
mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + results = restore.restore_batch( + tenant_ids=["tenant-1"], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert len(results) == 1 + assert results[0].success is True + + +# --------------------------------------------------------------------------- +# Test restore_by_run_id Method +# --------------------------------------------------------------------------- + + +class TestRestoreByRunId: + """Tests for WorkflowRunRestore.restore_by_run_id method.""" + + def test_archive_log_not_found(self): + """Should handle case when archive log is not found.""" + restore = WorkflowRunRestore() + + mock_repo = Mock() + mock_repo.get_archived_log_by_run_id.return_value = None + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore.restore_by_run_id("nonexistent-run") + + assert result.success is False + assert "not found" in result.error + assert result.run_id == "nonexistent-run" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_successful_restore_by_id(self, mock_sessionmaker): + """Should successfully restore by run ID.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={}) + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with 
patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + actual_result = restore.restore_by_run_id("run-1") + + assert actual_result.success is True + assert actual_result.run_id == "run-1" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_dry_run_restore_by_id(self, mock_sessionmaker): + """Should handle dry run mode for restore by ID.""" + restore = WorkflowRunRestore(dry_run=True) + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={"workflow_runs": 1}) + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + actual_result = restore.restore_by_run_id("run-1") + + assert actual_result.success is True + assert actual_result.run_id == "run-1" + + +# --------------------------------------------------------------------------- +# Test RestoreResult Dataclass +# --------------------------------------------------------------------------- + + +class TestRestoreResult: + """Tests for RestoreResult dataclass.""" + + def test_restore_result_creation(self): + """Should create RestoreResult with all fields.""" + result = RestoreResult( + run_id="run-123", + tenant_id="tenant-123", + 
success=True, + restored_counts={"workflow_runs": 1, "workflow_app_logs": 2}, + error=None, + elapsed_time=5.5, + ) + + assert result.run_id == "run-123" + assert result.tenant_id == "tenant-123" + assert result.success is True + assert result.restored_counts == {"workflow_runs": 1, "workflow_app_logs": 2} + assert result.error is None + assert result.elapsed_time == 5.5 + + def test_restore_result_with_error(self): + """Should create RestoreResult with error.""" + result = RestoreResult( + run_id="run-123", + tenant_id="tenant-123", + success=False, + restored_counts={}, + error="Something went wrong", + ) + + assert result.success is False + assert result.error == "Something went wrong" + assert result.restored_counts == {} + assert result.elapsed_time == 0.0 # Default value + + +# --------------------------------------------------------------------------- +# Test Constants and Mappings +# --------------------------------------------------------------------------- + + +class TestConstantsAndMappings: + """Tests for module constants and mappings.""" + + def test_table_models_mapping(self): + """TABLE_MODELS should contain expected table mappings.""" + expected_tables = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, + } + + assert expected_tables == TABLE_MODELS + + def test_schema_mappers_structure(self): + """SCHEMA_MAPPERS should have correct structure.""" + assert isinstance(SCHEMA_MAPPERS, dict) + assert "1.0" in SCHEMA_MAPPERS + assert isinstance(SCHEMA_MAPPERS["1.0"], dict) + + +# --------------------------------------------------------------------------- +# Integration Tests +# --------------------------------------------------------------------------- + + +class TestIntegration: 
+ """Integration tests combining multiple components.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_full_restore_flow(self, mock_executor, mock_get_storage): + """Test complete restore flow with all components.""" + restore = WorkflowRunRestore(workers=1) + + # Mock storage + mock_storage = Mock() + manifest = { + "schema_version": "1.0", + "tables": { + "workflow_runs": {"row_count": 1}, + }, + } + tables_data = { + "workflow_runs": [ + { + "id": "run-123", + "tenant_id": "tenant-123", + "app_id": "app-123", + "created_at": "2024-01-01T12:00:00", + } + ], + } + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock(manifest, tables_data) + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + mock_result = Mock() + mock_result.rowcount = 1 + mock_session.execute.return_value = mock_result + + # Mock repository + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + # Mock ThreadPoolExecutor (not actually used in restore_by_run_id but needed for patch) + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with 
patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") as mock_insert: + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_insert.return_value = mock_stmt + + with patch("services.retention.workflow_run.restore_archived_workflow_run.cast") as mock_cast: + mock_cast.return_value = mock_result + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + result = restore.restore_by_run_id("run-123") + + assert result.success is True + assert result.restored_counts.get("workflow_runs") == 1 diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py new file mode 100644 index 0000000000..a6bc79e82b --- /dev/null +++ b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py @@ -0,0 +1,214 @@ +""" +Unit tests for services.advanced_prompt_template_service +""" + +import copy + +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) +from models.model import AppMode +from services.advanced_prompt_template_service import AdvancedPromptTemplateService + + +class TestAdvancedPromptTemplateService: + """Test suite for 
AdvancedPromptTemplateService.""" + + def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: + """Test baichuan model names use baichuan context prompt.""" + # Arrange + args = { + "app_mode": AppMode.CHAT, + "model_mode": "chat", + "model_name": "Baichuan2-13B", + "has_context": "true", + } + + # Act + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) + + def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: + """Test non-baichuan model names use common prompt.""" + # Arrange + args = { + "app_mode": AppMode.CHAT, + "model_mode": "completion", + "model_name": "gpt-4", + "has_context": "false", + } + original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert + assert result == original_config + assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: + """Test invalid app mode returns empty dict.""" + # Arrange + app_mode = "invalid" + model_mode = "chat" + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") + + # Assert + assert result == {} + + def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: + """Test context is prepended for completion prompt when has_context is true.""" + # Arrange + original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) + assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: + """Test context is 
prepended for chat prompt when has_context is true.""" + # Arrange + original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) + assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: + """Test chat prompt remains unchanged when has_context is false.""" + # Arrange + original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") + + # Assert + assert result == original_config + assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG + + def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: + """Test completion app mode with completion model returns completion prompt.""" + # Arrange + original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") + + # Assert + assert result == original_config + assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: + """Test invalid model mode returns empty dict.""" + # Arrange + app_mode = AppMode.CHAT + model_mode = "invalid" + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") + + # Assert + assert result == {} + + def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: + """Test helper keeps completion prompt unchanged when context is disabled.""" + # Arrange + prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_text = 
prompt_template["completion_prompt_config"]["prompt"]["text"] + + # Act + result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"] == original_text + + def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: + """Test helper keeps chat prompt unchanged when context is disabled.""" + # Arrange + prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] + + # Act + result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text + + def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: + """Test baichuan chat/completion returns the expected config.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") + + # Assert + assert result == original_config + assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: + """Test baichuan completion/chat returns the expected config.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") + + # Assert + assert result == original_config + assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: + """Test baichuan completion/completion prepends baichuan context when enabled.""" + # Arrange + original_config = 
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) + assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: + """Test baichuan chat/chat prepends baichuan context when enabled.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) + assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: + """Test invalid baichuan mode combinations return empty dict.""" + # Arrange + app_mode = "invalid" + model_mode = "invalid" + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") + + # Assert + assert result == {} diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py new file mode 100644 index 0000000000..7ce3d7ef7b --- /dev/null +++ b/api/tests/unit_tests/services/test_agent_service.py @@ -0,0 +1,346 @@ +""" +Unit tests for services.agent_service +""" + +from collections.abc import Callable +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +import pytz + +from core.plugin.impl.exc import PluginDaemonClientSideError +from models import Account +from models.model import App, Conversation, EndUser, Message, MessageAgentThought +from services.agent_service import AgentService + + +def _make_current_user_account(timezone: str = "UTC") -> Account: + account = 
Account(name="Test User", email="test@example.com") + account.timezone = timezone + return account + + +def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: + app_model = MagicMock(spec=App) + app_model.id = "app-123" + app_model.tenant_id = "tenant-123" + app_model.app_model_config = app_model_config + return app_model + + +def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: + conversation = MagicMock(spec=Conversation) + conversation.id = "conv-123" + conversation.app_id = "app-123" + conversation.from_end_user_id = from_end_user_id + conversation.from_account_id = from_account_id + return conversation + + +def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: + message = MagicMock(spec=Message) + message.id = "msg-123" + message.conversation_id = "conv-123" + message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) + message.provider_response_latency = 1.23 + message.answer_tokens = 4 + message.message_tokens = 6 + message.agent_thoughts = agent_thoughts + message.message_files = ["file-a.txt"] + return message + + +def _make_agent_thought() -> MagicMock: + agent_thought = MagicMock(spec=MessageAgentThought) + agent_thought.tokens = 3 + agent_thought.tool_input = "raw-input" + agent_thought.observation = "raw-output" + agent_thought.thought = "thinking" + agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) + agent_thought.files = [] + agent_thought.tools = ["tool_a", "dataset_tool"] + agent_thought.tool_labels = {"tool_a": "Tool A"} + agent_thought.tool_meta = { + "tool_a": { + "tool_config": { + "tool_provider_type": "custom", + "tool_provider": "provider-1", + }, + "tool_parameters": {"param": "value"}, + "time_cost": 2.5, + }, + "dataset_tool": { + "tool_config": { + "tool_provider_type": "dataset-retrieval", + "tool_provider": "dataset-provider", + } + }, + } + agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} + 
agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} + return agent_thought + + +def _build_query_side_effect( + conversation: Conversation | None, + message: Message | None, + executor: EndUser | Account | None, +) -> Callable[..., MagicMock]: + def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: + query = MagicMock() + query.where.return_value = query + if any(arg is Conversation for arg in args): + query.first.return_value = conversation + elif any(arg is Message for arg in args): + query.first.return_value = message + elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): + query.first.return_value = executor + return query + + return _query_side_effect + + +class TestAgentServiceGetAgentLogs: + """Test suite for AgentService.get_agent_logs.""" + + def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: + """Test missing conversation raises ValueError.""" + # Arrange + app_model = _make_app_model(MagicMock()) + with patch("services.agent_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") + + def test_get_agent_logs_should_raise_when_message_missing(self) -> None: + """Test missing message raises ValueError.""" + # Arrange + app_model = _make_app_model(MagicMock()) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + with patch("services.agent_service.db") as mock_db: + conversation_query = MagicMock() + conversation_query.where.return_value = conversation_query + conversation_query.first.return_value = conversation + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = None + + mock_db.session.query.side_effect = [conversation_query, message_query] + + # Act & 
Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") + + def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: + """Test missing app model config raises ValueError.""" + # Arrange + app_model = _make_app_model(None) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + message = _make_message([]) + current_user = _make_current_user_account() + + with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, message.id) + + def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: + """Test missing agent config raises ValueError.""" + # Arrange + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + message = _make_message([]) + current_user = _make_current_user_account() + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=None), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, message.id) + + def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: + """Test agent logs returned for end-user executor with tool icons.""" + # Arrange + agent_thought = _make_agent_thought() + message = _make_message([agent_thought]) + 
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + executor = MagicMock(spec=EndUser) + executor.name = "End User" + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_tool = MagicMock() + agent_tool.tool_name = "tool_a" + agent_tool.provider_type = "custom" + agent_tool.provider_id = "provider-2" + agent_config = MagicMock() + agent_config.tools = [agent_tool] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, + patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) + mock_get_icon.side_effect = [None, "icon-a"] + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["status"] == "success" + assert result["meta"]["executor"] == "End User" + assert result["meta"]["total_tokens"] == 10 + assert result["meta"]["agent_mode"] == "react" + assert result["meta"]["iterations"] == 1 + assert result["files"] == ["file-a.txt"] + assert len(result["iterations"]) == 1 + tool_calls = result["iterations"][0]["tool_calls"] + assert tool_calls[0]["tool_name"] == "tool_a" + assert tool_calls[0]["tool_icon"] == "icon-a" + assert tool_calls[1]["tool_name"] == "dataset_tool" + assert tool_calls[1]["tool_icon"] == "" + mock_convert.assert_called_once() + + def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: + """Test agent logs fall back to account executor when end user is missing.""" + # Arrange + agent_thought = _make_agent_thought() + message = 
_make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") + executor = MagicMock(spec=Account) + executor.name = "Account User" + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_config = MagicMock() + agent_config.tools = [] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), + patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["executor"] == "Account User" + + def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: + """Test unknown executor and missing tool details fall back to defaults.""" + # Arrange + agent_thought = _make_agent_thought() + agent_thought.tool_labels = {} + agent_thought.tool_inputs_dict = {} + agent_thought.tool_outputs_dict = None + agent_thought.tool_meta = {"tool_a": {"error": "failed"}} + agent_thought.tools = ["tool_a"] + + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_config = MagicMock() + agent_config.tools = [] + + with ( + patch("services.agent_service.db") as mock_db, + 
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), + patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["executor"] == "Unknown" + assert result["meta"]["agent_mode"] == "react" + tool_call = result["iterations"][0]["tool_calls"][0] + assert tool_call["status"] == "error" + assert tool_call["error"] == "failed" + assert tool_call["tool_label"] == "tool_a" + assert tool_call["tool_input"] == {} + assert tool_call["tool_output"] == {} + assert tool_call["time_cost"] == 0 + assert tool_call["tool_parameters"] == {} + assert tool_call["tool_icon"] is None + + +class TestAgentServiceProviders: + """Test suite for AgentService provider methods.""" + + def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: + """Test list_agent_providers delegates to PluginAgentClient.""" + # Arrange + tenant_id = "tenant-1" + expected = [{"name": "provider"}] + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_providers.return_value = expected + + # Act + result = AgentService.list_agent_providers("user-1", tenant_id) + + # Assert + assert result == expected + mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) + + def test_get_agent_provider_should_return_provider_when_successful(self) -> None: + """Test get_agent_provider returns provider when successful.""" + # Arrange + tenant_id = "tenant-1" + provider_name = "provider-a" + expected = {"name": provider_name} + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_provider.return_value = expected + + 
# Act + result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) + + # Assert + assert result == expected + mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) + + def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: + """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" + # Arrange + tenant_id = "tenant-1" + provider_name = "provider-a" + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( + "plugin error" + ) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_annotation_service.py b/api/tests/unit_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..0aacfc7f13 --- /dev/null +++ b/api/tests/unit_tests/services/test_annotation_service.py @@ -0,0 +1,1685 @@ +""" +Unit tests for services.annotation_service +""" + +from io import BytesIO +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + +from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation +from services.annotation_service import AppAnnotationService + + +def _make_app(app_id: str = "app-1", tenant_id: str = "tenant-1") -> MagicMock: + app = MagicMock(spec=App) + app.id = app_id + app.tenant_id = tenant_id + app.status = "normal" + return app + + +def _make_user(user_id: str = "user-1") -> MagicMock: + user = MagicMock() + user.id = user_id + return user + + +def _make_message(message_id: str = "msg-1", app_id: str = "app-1") -> MagicMock: + message = MagicMock(spec=Message) + message.id = 
message_id + message.app_id = app_id + message.conversation_id = "conv-1" + message.query = "default-question" + message.annotation = None + return message + + +def _make_annotation(annotation_id: str = "ann-1") -> MagicMock: + annotation = MagicMock(spec=MessageAnnotation) + annotation.id = annotation_id + annotation.content = "" + annotation.question = "" + annotation.question_text = "" + return annotation + + +def _make_setting(setting_id: str = "setting-1", with_detail: bool = True) -> MagicMock: + setting = MagicMock(spec=AppAnnotationSetting) + setting.id = setting_id + setting.score_threshold = 0.5 + setting.collection_binding_id = "collection-1" + if with_detail: + setting.collection_binding_detail = SimpleNamespace(provider_name="provider-a", model_name="model-a") + else: + setting.collection_binding_detail = None + return setting + + +def _make_file(content: bytes) -> FileStorage: + return FileStorage(stream=BytesIO(content)) + + +class TestAppAnnotationServiceUpInsert: + """Test suite for up_insert_app_annotation_from_message.""" + + def test_up_insert_app_annotation_from_message_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.up_insert_app_annotation_from_message(args, "app-1") + + def test_up_insert_app_annotation_from_message_should_raise_value_error_when_answer_missing(self) -> None: + """Test missing answer and content raises ValueError.""" + # Arrange + args = {"message_id": "msg-1"} 
+ current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_raise_not_found_when_message_missing(self) -> None: + """Test missing message raises NotFound.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, message_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_update_existing_annotation_when_found(self) -> None: + """Test existing annotation is updated and indexed.""" + # Arrange + args = {"answer": "updated", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + message = _make_message(message_id="msg-1", app_id=app.id) + message.annotation = annotation + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", 
return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = message + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, message_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation + assert annotation.content == "updated" + assert annotation.question == message.query + mock_db.session.add.assert_called_once_with(annotation) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + message.query, + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_up_insert_app_annotation_from_message_should_create_annotation_when_message_has_no_annotation( + self, + ) -> None: + """Test new annotation is created when message has no annotation.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1", "question": "q1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + message = _make_message(message_id="msg-1", app_id=app.id) + message.annotation = None + annotation_instance = _make_annotation("ann-1") + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + 
app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = message + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, message_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + conversation_id=message.conversation_id, + message_id=message.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question without message_id raises ValueError.""" + # Arrange + args = {"answer": "hello"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_create_annotation_when_message_missing(self) -> None: + """Test annotation is created when message_id is not provided.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + annotation_instance = 
_make_annotation("ann-1") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation_instance.id, + "q1", + tenant_id, + app.id, + setting.collection_binding_id, + ) + + +class TestAppAnnotationServiceEnableDisable: + """Test suite for enable/disable app annotation.""" + + def test_enable_app_annotation_should_return_processing_when_cache_hit(self) -> None: + """Test cache hit returns processing status.""" + # Arrange + args = {"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"} + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.enable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = "job-1" + + # Act + result = AppAnnotationService.enable_app_annotation(args, "app-1") + + # Assert + assert result == {"job_id": "job-1", "job_status": "processing"} + mock_task.delay.assert_not_called() + + def 
test_enable_app_annotation_should_enqueue_job_when_cache_miss(self) -> None: + """Test cache miss enqueues enable task.""" + # Arrange + args = {"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"} + current_user = _make_user("user-1") + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.uuid.uuid4", return_value="uuid-1"), + patch("services.annotation_service.enable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = None + + # Act + result = AppAnnotationService.enable_app_annotation(args, "app-1") + + # Assert + assert result == {"job_id": "uuid-1", "job_status": "waiting"} + mock_redis.setnx.assert_called_once_with("enable_app_annotation_job_uuid-1", "waiting") + mock_task.delay.assert_called_once_with( + "uuid-1", + "app-1", + current_user.id, + tenant_id, + 0.5, + "p", + "m", + ) + + def test_disable_app_annotation_should_return_processing_when_cache_hit(self) -> None: + """Test disable cache hit returns processing status.""" + # Arrange + tenant_id = "tenant-1" + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.disable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = "job-2" + + # Act + result = AppAnnotationService.disable_app_annotation("app-1") + + # Assert + assert result == {"job_id": "job-2", "job_status": "processing"} + mock_task.delay.assert_not_called() + + def test_disable_app_annotation_should_enqueue_job_when_cache_miss(self) -> None: + """Test disable cache miss enqueues disable task.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + 
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.uuid.uuid4", return_value="uuid-2"), + patch("services.annotation_service.disable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = None + + # Act + result = AppAnnotationService.disable_app_annotation("app-1") + + # Assert + assert result == {"job_id": "uuid-2", "job_status": "waiting"} + mock_redis.setnx.assert_called_once_with("disable_app_annotation_job_uuid-2", "waiting") + mock_task.delay.assert_called_once_with("uuid-2", "app-1", tenant_id) + + +class TestAppAnnotationServiceListAndExport: + """Test suite for list and export methods.""" + + def test_get_annotation_list_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_list_by_app_id("app-1", 1, 10, "") + + def test_get_annotation_list_by_app_id_should_return_items_with_keyword(self) -> None: + """Test keyword search returns items and total.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + pagination = SimpleNamespace(items=["a1"], total=1) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("libs.helper.escape_like_pattern", return_value="safe"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = 
app_query + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_list_by_app_id(app.id, 1, 10, "keyword") + + # Assert + assert items == ["a1"] + assert total == 1 + + def test_get_annotation_list_by_app_id_should_return_items_without_keyword(self) -> None: + """Test list query without keyword returns paginated items.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + pagination = SimpleNamespace(items=["a1", "a2"], total=2) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_list_by_app_id(app.id, 1, 10, "") + + # Assert + assert items == ["a1", "a2"] + assert total == 2 + + def test_export_annotation_list_by_app_id_should_sanitize_fields(self) -> None: + """Test export sanitizes question and content fields.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation1 = _make_annotation("ann-1") + annotation1.question = "=cmd" + annotation1.content = "+1" + annotation2 = _make_annotation("ann-2") + annotation2.question = "@bad" + annotation2.content = "-2" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.CSVSanitizer.sanitize_value", side_effect=lambda v: f"safe:{v}"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.order_by.return_value = annotation_query + 
annotation_query.all.return_value = [annotation1, annotation2] + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act + result = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Assert + assert result == [annotation1, annotation2] + assert annotation1.question == "safe:=cmd" + assert annotation1.content == "safe:+1" + assert annotation2.question == "safe:@bad" + assert annotation2.content == "safe:-2" + + def test_export_annotation_list_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test export raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.export_annotation_list_by_app_id("app-1") + + +class TestAppAnnotationServiceDirectManipulation: + """Test suite for direct insert/update/delete methods.""" + + def test_insert_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: + """Test insert raises NotFound when app is missing.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.insert_app_annotation_directly(args, "app-1") + + def test_insert_app_annotation_directly_should_raise_value_error_when_question_missing(self) 
-> None: + """Test missing question raises ValueError.""" + # Arrange + args = {"answer": "hello"} + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.insert_app_annotation_directly(args, app.id) + + def test_insert_app_annotation_directly_should_create_annotation_and_index(self) -> None: + """Test insert creates annotation and triggers index task.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + current_user = _make_user("user-1") + tenant_id = "tenant-1" + app = _make_app() + annotation_instance = _make_annotation("ann-1") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.insert_app_annotation_directly(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + 
mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation_instance.id, + "q1", + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_update_app_annotation_directly_should_raise_not_found_when_annotation_missing(self) -> None: + """Test missing annotation raises NotFound.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1") + + def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound in update path.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1") + + def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question raises ValueError.""" + # Arrange + args = {"answer": "hello"} + tenant_id = "tenant-1" + app = 
_make_app() + annotation = _make_annotation("ann-1") + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + + def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None: + """Test update changes fields and triggers index update.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + annotation.question_text = "q1" + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.update_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + + # Assert + assert result == annotation + assert annotation.content == "hello" + assert annotation.question == "q1" + 
mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + annotation.question_text, + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_delete_app_annotation_should_delete_annotation_and_histories(self) -> None: + """Test delete removes annotation and hit histories.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + history1 = MagicMock(spec=AppAnnotationHitHistory) + history2 = MagicMock(spec=AppAnnotationHitHistory) + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + scalars_result = MagicMock() + scalars_result.all.return_value = [history1, history2] + + mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + mock_db.session.scalars.return_value = scalars_result + + # Act + AppAnnotationService.delete_app_annotation(app.id, annotation.id) + + # Assert + mock_db.session.delete.assert_any_call(annotation) + mock_db.session.delete.assert_any_call(history1) + mock_db.session.delete.assert_any_call(history2) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + app.id, + tenant_id, + setting.collection_binding_id, + ) + + def test_delete_app_annotation_should_raise_not_found_when_app_missing(self) -> None: + """Test delete raises NotFound when app is missing.""" + 
# Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotation("app-1", "ann-1") + + def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None: + """Test delete raises NotFound when annotation is missing.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotation(app.id, "ann-1") + + def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None: + """Test batch delete returns zero when no annotations found.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotations_query = MagicMock() + annotations_query.outerjoin.return_value = annotations_query + annotations_query.where.return_value = annotations_query + 
annotations_query.all.return_value = [] + + mock_db.session.query.side_effect = [app_query, annotations_query] + + # Act + result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1"]) + + # Assert + assert result == {"deleted_count": 0} + + def test_delete_app_annotations_in_batch_should_raise_not_found_when_app_missing(self) -> None: + """Test batch delete raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotations_in_batch("app-1", ["ann-1"]) + + def test_delete_app_annotations_in_batch_should_delete_annotations_and_histories(self) -> None: + """Test batch delete removes annotations and triggers index deletion.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation1 = _make_annotation("ann-1") + annotation2 = _make_annotation("ann-2") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotations_query = MagicMock() + annotations_query.outerjoin.return_value = annotations_query + annotations_query.where.return_value = annotations_query + annotations_query.all.return_value = [(annotation1, setting), (annotation2, None)] + + hit_history_query = MagicMock() + hit_history_query.where.return_value = hit_history_query + 
hit_history_query.delete.return_value = None + + delete_query = MagicMock() + delete_query.where.return_value = delete_query + delete_query.delete.return_value = 2 + + mock_db.session.query.side_effect = [app_query, annotations_query, hit_history_query, delete_query] + + # Act + result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1", "ann-2"]) + + # Assert + assert result == {"deleted_count": 2} + mock_task.delay.assert_called_once_with(annotation1.id, app.id, tenant_id, setting.collection_binding_id) + mock_db.session.commit.assert_called_once() + + +class TestAppAnnotationServiceBatchImport: + """Test suite for batch import.""" + + def test_batch_import_app_annotations_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.batch_import_app_annotations("app-1", file) + + def test_batch_import_app_annotations_should_return_error_when_columns_invalid(self) -> None: + """Test invalid column count returns error message.""" + # Arrange + file = _make_file(b"question\nq\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["only"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): 
+ app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Invalid CSV format" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_file_empty(self) -> None: + """Test empty file returns validation error before CSV parsing.""" + # Arrange + file = _make_file(b"") + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "empty or invalid" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_min_records_not_met(self) -> None: + """Test min records validation returns error message.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch( + 
"configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=2), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_row_limit_exceeded(self) -> None: + """Test row count over max limit returns explicit error.""" + # Arrange + file = _make_file(b"question,answer\nq1,a1\nq2,a2\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1", "q2"], "a": ["a1", "a2"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=1, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "too many records" in error_msg + + def test_batch_import_app_annotations_should_skip_malformed_rows_and_fail_min_records(self) -> None: + """Test malformed row extraction is skipped and can fail min record validation.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + malformed_row = MagicMock() + malformed_row.iloc.__getitem__.side_effect = IndexError() + df = MagicMock() + df.columns = ["q", "a"] + df.iterrows.return_value = [(0, malformed_row)] + + with ( + 
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_skip_nan_rows_and_fail_min_records(self) -> None: + """Test NaN rows are skipped by validation and reported via min record check.""" + # Arrange + file = _make_file(b"question,answer\nnan,nan\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["nan"], "a": ["nan"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_question_too_long(self) -> None: + """Test oversized question is rejected with row context.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = 
_make_app() + df = pd.DataFrame({"q": ["q" * 2001], "a": ["a"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Question at row" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_answer_too_long(self) -> None: + """Test oversized answer is rejected with row context.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q"], "a": ["a" * 10001]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Answer at row" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_quota_exceeded(self) -> None: + """Test quota validation returns error message.""" + # Arrange + file = 
_make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True), + annotation_quota_limit=SimpleNamespace(limit=1, size=1), + ) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "exceeds the limit" in error_msg + + def test_batch_import_app_annotations_should_enqueue_job_when_valid(self) -> None: + """Test successful batch import enqueues job and returns status.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch("services.annotation_service.batch_import_annotations_task") as mock_task, + patch("services.annotation_service.redis_client") as mock_redis, 
+ patch("services.annotation_service.uuid.uuid4", return_value="uuid-3"), + patch("services.annotation_service.naive_utc_now", return_value=SimpleNamespace(timestamp=lambda: 1)), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + assert result == {"job_id": "uuid-3", "job_status": "waiting", "record_count": 1} + mock_redis.zadd.assert_called_once() + mock_redis.expire.assert_called_once() + mock_redis.setnx.assert_called_once_with("app_annotation_batch_import_uuid-3", "waiting") + mock_task.delay.assert_called_once() + + def test_batch_import_app_annotations_should_cleanup_active_job_on_unexpected_exception(self) -> None: + """Test unexpected runtime errors trigger cleanup and return wrapped error.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.uuid.uuid4", return_value="uuid-4"), + patch("services.annotation_service.naive_utc_now", return_value=SimpleNamespace(timestamp=lambda: 1)), + patch("services.annotation_service.logger") as mock_logger, + patch( + 
"configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_redis.zadd.side_effect = RuntimeError("boom") + mock_redis.zrem.side_effect = RuntimeError("cleanup-failed") + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + assert result["error_msg"] == "An error occurred while processing the file: boom" + mock_redis.zrem.assert_called_once_with(f"annotation_import_active:{tenant_id}", "uuid-4") + mock_logger.debug.assert_called_once() + + +class TestAppAnnotationServiceHitHistoryAndSettings: + """Test suite for hit history and settings methods.""" + + def test_get_annotation_hit_histories_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_hit_histories("app-1", "ann-1", 1, 10) + + def test_get_annotation_hit_histories_should_return_items_and_total(self) -> None: + """Test hit histories pagination returns items and total.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + pagination = SimpleNamespace(items=["h1"], total=2) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value 
= app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_hit_histories(app.id, annotation.id, 1, 10) + + # Assert + assert items == ["h1"] + assert total == 2 + + def test_get_annotation_hit_histories_should_raise_not_found_when_annotation_missing(self) -> None: + """Test missing annotation raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_hit_histories(app.id, "ann-1", 1, 10) + + def test_get_annotation_by_id_should_return_none_when_missing(self) -> None: + """Test get_annotation_by_id returns None when not found.""" + # Arrange + with patch("services.annotation_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act + result = AppAnnotationService.get_annotation_by_id("ann-1") + + # Assert + assert result is None + + def test_get_annotation_by_id_should_return_annotation_when_exists(self) -> None: + """Test get_annotation_by_id returns annotation when found.""" + # Arrange + annotation = _make_annotation("ann-1") + with 
patch("services.annotation_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = annotation + mock_db.session.query.return_value = query + + # Act + result = AppAnnotationService.get_annotation_by_id("ann-1") + + # Assert + assert result == annotation + + def test_add_annotation_history_should_update_hit_count_and_store_history(self) -> None: + """Test add_annotation_history updates hit count and creates history.""" + # Arrange + with ( + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.AppAnnotationHitHistory") as mock_history_cls, + ): + query = MagicMock() + query.where.return_value = query + mock_db.session.query.return_value = query + + # Act + AppAnnotationService.add_annotation_history( + annotation_id="ann-1", + app_id="app-1", + annotation_question="q", + annotation_content="a", + query="q", + user_id="user-1", + message_id="msg-1", + from_source="chat", + score=0.8, + ) + + # Assert + query.update.assert_called_once() + mock_history_cls.assert_called_once() + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_get_app_annotation_setting_by_app_id_should_return_embedding_model_when_detail_exists(self) -> None: + """Test setting detail returns embedding model info.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting(with_detail=True) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = 
AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result["enabled"] is True + embedding_model = cast(dict[str, Any], result["embedding_model"]) + assert embedding_model["embedding_provider_name"] == "provider-a" + assert embedding_model["embedding_model_name"] == "model-a" + + def test_get_app_annotation_setting_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_app_annotation_setting_by_app_id("app-1") + + def test_get_app_annotation_setting_by_app_id_should_return_empty_embedding_model_when_no_detail(self) -> None: + """Test setting without detail returns empty embedding model.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting(with_detail=False) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result["enabled"] is True + assert result["embedding_model"] == {} + + def test_get_app_annotation_setting_by_app_id_should_return_disabled_when_setting_missing(self) -> None: 
+ """Test missing setting returns disabled payload.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result == {"enabled": False} + + def test_update_app_annotation_setting_should_update_and_return_detail(self) -> None: + """Test update_app_annotation_setting updates fields and returns detail.""" + # Arrange + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + setting = _make_setting(with_detail=True) + args = {"score_threshold": 0.8} + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.naive_utc_now", return_value="now"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) + + # Assert + assert result["enabled"] is True + assert result["score_threshold"] == 0.8 + embedding_model = cast(dict[str, Any], result["embedding_model"]) + assert embedding_model["embedding_provider_name"] == "provider-a" + mock_db.session.add.assert_called_once_with(setting) 
+ mock_db.session.commit.assert_called_once() + + def test_update_app_annotation_setting_should_return_empty_embedding_model_when_detail_missing(self) -> None: + """Test update returns empty embedding_model when collection detail is absent.""" + # Arrange + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + setting = _make_setting(with_detail=False) + args = {"score_threshold": 0.7} + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.naive_utc_now", return_value="now"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) + + # Assert + assert result["enabled"] is True + assert result["score_threshold"] == 0.7 + assert result["embedding_model"] == {} + + def test_update_app_annotation_setting_should_raise_not_found_when_app_missing(self) -> None: + """Test update raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_setting("app-1", "setting-1", {"score_threshold": 0.5}) + + def test_update_app_annotation_setting_should_raise_not_found_when_setting_missing(self) -> None: + """Test update 
raises NotFound when setting is missing.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_setting(app.id, "setting-1", {"score_threshold": 0.5}) + + +class TestAppAnnotationServiceClearAll: + """Test suite for clear_all_annotations.""" + + def test_clear_all_annotations_should_delete_annotations_and_histories(self) -> None: + """Test clear_all_annotations deletes all data and triggers index removal.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting() + annotation1 = _make_annotation("ann-1") + annotation2 = _make_annotation("ann-2") + history = MagicMock(spec=AppAnnotationHitHistory) + + def query_side_effect(*args: object, **kwargs: object) -> MagicMock: + query = MagicMock() + query.where.return_value = query + if App in args: + query.first.return_value = app + elif AppAnnotationSetting in args: + query.first.return_value = setting + elif MessageAnnotation in args: + query.yield_per.return_value = [annotation1, annotation2] + elif AppAnnotationHitHistory in args: + query.yield_per.return_value = [history] + return query + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + mock_db.session.query.side_effect = query_side_effect + + # Act + result = 
AppAnnotationService.clear_all_annotations(app.id) + + # Assert + assert result == {"result": "success"} + mock_db.session.delete.assert_any_call(annotation1) + mock_db.session.delete.assert_any_call(annotation2) + mock_db.session.delete.assert_any_call(history) + mock_task.delay.assert_any_call(annotation1.id, app.id, tenant_id, setting.collection_binding_id) + mock_task.delay.assert_any_call(annotation2.id, app.id, tenant_id, setting.collection_binding_id) + mock_db.session.commit.assert_called_once() + + def test_clear_all_annotations_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.clear_all_annotations("app-1") diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py new file mode 100644 index 0000000000..7f4b5fdaa3 --- /dev/null +++ b/api/tests/unit_tests/services/test_api_based_extension_service.py @@ -0,0 +1,421 @@ +""" +Comprehensive unit tests for services/api_based_extension_service.py + +Covers: +- APIBasedExtensionService.get_all_by_tenant_id +- APIBasedExtensionService.save +- APIBasedExtensionService.delete +- APIBasedExtensionService.get_with_tenant_id +- APIBasedExtensionService._validation (new record & existing record branches) +- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from services.api_based_extension_service import APIBasedExtensionService + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_extension( + *, + id_: str | None = None, + tenant_id: str = "tenant-001", + name: str = "my-ext", + api_endpoint: str = "https://example.com/hook", + api_key: str = "secret-key-123", +) -> MagicMock: + """Return a lightweight mock that mimics APIBasedExtension.""" + ext = MagicMock() + ext.id = id_ + ext.tenant_id = tenant_id + ext.name = name + ext.api_endpoint = api_endpoint + ext.api_key = api_key + return ext + + +# --------------------------------------------------------------------------- +# Tests: get_all_by_tenant_id +# --------------------------------------------------------------------------- + + +class TestGetAllByTenantId: + """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): + """Each api_key is decrypted and the list is returned.""" + ext1 = _make_extension(id_="id-1", api_key="enc-key-1") + ext2 = _make_extension(id_="id-2", api_key="enc-key-2") + + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ + ext1, + ext2, + ] + + result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") + + assert result == [ext1, ext2] + assert ext1.api_key == "decrypted-key" + assert ext2.api_key == "decrypted-key" + assert mock_decrypt.call_count == 2 + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): + """Returns an empty list gracefully when no records exist.""" + 
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") + + assert result == [] + mock_decrypt.assert_not_called() + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): + """Verifies the DB is queried with the supplied tenant_id.""" + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") + + mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") + + +# --------------------------------------------------------------------------- +# Tests: save +# --------------------------------------------------------------------------- + + +class TestSave: + """Tests for APIBasedExtensionService.save.""" + + @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") + @patch("services.api_based_extension_service.db") + @patch.object(APIBasedExtensionService, "_validation") + def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): + """Happy path: validation passes, key is encrypted, record is added and committed.""" + ext = _make_extension(id_=None, api_key="plain-key-123") + + result = APIBasedExtensionService.save(ext) + + mock_validation.assert_called_once_with(ext) + mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") + assert ext.api_key == "encrypted-key" + mock_db.session.add.assert_called_once_with(ext) + mock_db.session.commit.assert_called_once() + assert result is ext + + @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") + @patch("services.api_based_extension_service.db") + @patch.object(APIBasedExtensionService, 
"_validation", side_effect=ValueError("name must not be empty")) + def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): + """If _validation raises, save should propagate the error without touching the DB.""" + ext = _make_extension(name="") + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(ext) + + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: delete +# --------------------------------------------------------------------------- + + +class TestDelete: + """Tests for APIBasedExtensionService.delete.""" + + @patch("services.api_based_extension_service.db") + def test_delete_removes_record_and_commits(self, mock_db): + """delete() must call session.delete with the extension and then commit.""" + ext = _make_extension(id_="delete-me") + + APIBasedExtensionService.delete(ext) + + mock_db.session.delete.assert_called_once_with(ext) + mock_db.session.commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests: get_with_tenant_id +# --------------------------------------------------------------------------- + + +class TestGetWithTenantId: + """Tests for APIBasedExtensionService.get_with_tenant_id.""" + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): + """Found extension has its api_key decrypted before being returned.""" + ext = _make_extension(id_="ext-123", api_key="enc-key") + + (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext + + result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") + + assert result is ext + assert ext.api_key == 
"decrypted-key" + mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") + + @patch("services.api_based_extension_service.db") + def test_raises_value_error_when_not_found(self, mock_db): + """Raises ValueError when no matching extension exists.""" + (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None + + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): + """Verifies both tenant_id and extension id are used in the query.""" + ext = _make_extension(id_="ext-abc") + chain = mock_db.session.query.return_value + chain.filter_by.return_value.filter_by.return_value.first.return_value = ext + + APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") + + # First filter_by call uses tenant_id + chain.filter_by.assert_called_once_with(tenant_id="tenant-002") + # Second filter_by call uses id + chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") + + +# --------------------------------------------------------------------------- +# Tests: _validation (new record — id is falsy) +# --------------------------------------------------------------------------- + + +class TestValidationNewRecord: + """Tests for _validation() with a brand-new record (no id).""" + + def _build_mock_db(self, name_exists: bool = False): + mock_db = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( + MagicMock() if name_exists else None + ) + return mock_db + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def 
test_valid_new_extension_passes(self, mock_db, mock_ping): + """A new record with all valid fields should pass without exceptions.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") + + # Should not raise + APIBasedExtensionService._validation(ext) + mock_ping.assert_called_once_with(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_is_empty(self, mock_db): + """Empty name raises ValueError.""" + ext = _make_extension(id_=None, name="") + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_is_none(self, mock_db): + """None name raises ValueError.""" + ext = _make_extension(id_=None, name=None) + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_already_exists_for_new_record(self, mock_db): + """A new record whose name already exists raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( + MagicMock() + ) + ext = _make_extension(id_=None, name="duplicate-name") + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_endpoint_is_empty(self, mock_db): + """Empty api_endpoint raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_endpoint="") + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService._validation(ext) + + 
@patch("services.api_based_extension_service.db") + def test_raises_if_api_endpoint_is_none(self, mock_db): + """None api_endpoint raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_endpoint=None) + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_is_empty(self, mock_db): + """Empty api_key raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="") + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_is_none(self, mock_db): + """None api_key raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key=None) + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_too_short(self, mock_db): + """api_key shorter than 5 characters raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="abc") + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_exactly_four_chars(self, mock_db): + """api_key with exactly 4 characters raises ValueError (boundary condition).""" + 
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="1234") + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService._validation(ext) + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): + """api_key with exactly 5 characters should pass (boundary condition).""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="12345") + + # Should not raise + APIBasedExtensionService._validation(ext) + + +# --------------------------------------------------------------------------- +# Tests: _validation (existing record — id is truthy) +# --------------------------------------------------------------------------- + + +class TestValidationExistingRecord: + """Tests for _validation() with an existing record (id is set).""" + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_valid_existing_extension_passes(self, mock_db, mock_ping): + """An existing record whose name is unique (excluding self) should pass.""" + # .where(...).first() → None means no *other* record has that name + ( + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value + ) = None + ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") + + # Should not raise + APIBasedExtensionService._validation(ext) + mock_ping.assert_called_once_with(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): + """Existing record cannot use a name already owned by a different record.""" + ( + 
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value + ) = MagicMock() + ext = _make_extension(id_="existing-id", name="taken-name") + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService._validation(ext) + + +# --------------------------------------------------------------------------- +# Tests: _ping_connection +# --------------------------------------------------------------------------- + + +class TestPingConnection: + """Tests for APIBasedExtensionService._ping_connection.""" + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_successful_ping_returns_pong(self, mock_requestor_class): + """When the endpoint returns {"result": "pong"}, no exception is raised.""" + mock_client = MagicMock() + mock_client.request.return_value = {"result": "pong"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") + # Should not raise + APIBasedExtensionService._ping_connection(ext) + + mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): + """When the response is not {"result": "pong"}, a ValueError is raised.""" + mock_client = MagicMock() + mock_client.request.return_value = {"result": "error"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_network_exception_wraps_in_value_error(self, mock_requestor_class): + """Any exception raised during request is wrapped in a ValueError.""" + mock_client = MagicMock() + 
mock_client.request.side_effect = ConnectionError("network failure") + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): + """Exception raised by the requestor constructor itself is wrapped.""" + mock_requestor_class.side_effect = RuntimeError("bad config") + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_missing_result_key_raises_value_error(self, mock_requestor_class): + """A response dict without a 'result' key does not equal 'pong' → raises.""" + mock_client = MagicMock() + mock_client.request.return_value = {} # no 'result' key + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_uses_ping_extension_point(self, mock_requestor_class): + """The PING extension point is passed to the client.request call.""" + from models.api_based_extension import APIBasedExtensionPoint + + mock_client = MagicMock() + mock_client.request.return_value = {"result": "pong"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + APIBasedExtensionService._ping_connection(ext) + + call_kwargs = mock_client.request.call_args + assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING + assert call_kwargs.kwargs["params"] == {} diff --git a/api/tests/unit_tests/services/test_api_token_service.py 
b/api/tests/unit_tests/services/test_api_token_service.py new file mode 100644 index 0000000000..ad4de93b25 --- /dev/null +++ b/api/tests/unit_tests/services/test_api_token_service.py @@ -0,0 +1,466 @@ +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Unauthorized + +import services.api_token_service as api_token_service_module +from services.api_token_service import ApiTokenCache, CachedApiToken + + +@pytest.fixture +def mock_db_session(): + """Fixture providing common DB session mocking for query_token_from_db tests.""" + fake_engine = MagicMock() + + session = MagicMock() + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + with ( + patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), + patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + yield { + "session": session, + "mock_session_class": mock_session_class, + "mock_cache_set": mock_cache_set, + "mock_record_usage": mock_record_usage, + "fake_engine": fake_engine, + } + + +class TestQueryTokenFromDb: + def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): + """Test DB lookup success path caches token and records usage.""" + # Arrange + auth_token = "token-123" + scope = "app" + api_token = MagicMock() + + mock_db_session["session"].scalar.return_value = api_token + + # Act + result = api_token_service_module.query_token_from_db(auth_token, scope) + + # Assert + assert result == api_token + mock_db_session["mock_session_class"].assert_called_once_with( + mock_db_session["fake_engine"], expire_on_commit=False + ) + 
mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) + mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + + def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): + """Test DB lookup miss path caches null marker and raises Unauthorized.""" + # Arrange + auth_token = "missing-token" + scope = "app" + + mock_db_session["session"].scalar.return_value = None + + # Act / Assert + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(auth_token, scope) + + mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) + mock_db_session["mock_record_usage"].assert_not_called() + + +class TestRecordTokenUsage: + def test_should_write_active_key_with_iso_timestamp_and_ttl(self): + """Test record_token_usage writes usage timestamp with one-hour TTL.""" + # Arrange + auth_token = "token-123" + scope = "dataset" + fixed_time = datetime(2026, 2, 24, 12, 0, 0) + expected_key = ApiTokenCache.make_active_key(auth_token, scope) + + with ( + patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), + patch.object(api_token_service_module, "redis_client") as mock_redis, + ): + # Act + api_token_service_module.record_token_usage(auth_token, scope) + + # Assert + mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) + + def test_should_not_raise_when_redis_write_fails(self): + """Test record_token_usage swallows Redis errors.""" + # Arrange + with patch.object(api_token_service_module, "redis_client") as mock_redis: + mock_redis.set.side_effect = Exception("redis unavailable") + + # Act / Assert + api_token_service_module.record_token_usage("token-123", "app") + + +class TestFetchTokenWithSingleFlight: + def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): + """Test single-flight returns cache when another request already 
populated it.""" + # Arrange + auth_token = "token-123" + scope = "app" + cached_token = CachedApiToken( + id="id-1", + app_id="app-1", + tenant_id="tenant-1", + type="app", + token=auth_token, + last_used_at=None, + created_at=None, + ) + + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == cached_token + mock_redis.lock.assert_called_once_with( + f"api_token_query_lock:{scope}:{auth_token}", + timeout=10, + blocking_timeout=5, + ) + lock.acquire.assert_called_once_with(blocking=True) + lock.release.assert_called_once() + mock_cache_get.assert_called_once_with(auth_token, scope) + mock_query_db.assert_not_called() + + def test_should_query_db_when_lock_acquired_and_cache_missed(self): + """Test single-flight queries DB when cache remains empty after lock acquisition.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None), + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_query_db.assert_called_once_with(auth_token, scope) + lock.release.assert_called_once() + + def test_should_query_db_directly_when_lock_not_acquired(self): + """Test lock 
timeout branch falls back to direct DB query.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.return_value = False + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get") as mock_cache_get, + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_cache_get.assert_not_called() + mock_query_db.assert_called_once_with(auth_token, scope) + lock.release.assert_not_called() + + def test_should_reraise_unauthorized_from_db_query(self): + """Test Unauthorized from DB query is propagated unchanged.""" + # Arrange + auth_token = "token-123" + scope = "app" + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None), + patch.object( + api_token_service_module, + "query_token_from_db", + side_effect=Unauthorized("Access token is invalid"), + ), + ): + mock_redis.lock.return_value = lock + + # Act / Assert + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + lock.release.assert_called_once() + + def test_should_fallback_to_db_query_when_lock_raises_exception(self): + """Test Redis lock errors fall back to direct DB query.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.side_effect = RuntimeError("redis lock error") + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module, "query_token_from_db", 
return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_query_db.assert_called_once_with(auth_token, scope) + + +class TestApiTokenCacheTenantBranches: + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): + """Test scoped delete removes cache key and tenant index membership.""" + # Arrange + token = "token-123" + scope = "app" + cache_key = ApiTokenCache._make_cache_key(token, scope) + cached_token = CachedApiToken( + id="id-1", + app_id="app-1", + tenant_id="tenant-1", + type="app", + token=token, + last_used_at=None, + created_at=None, + ) + mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") + + with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: + # Act + result = ApiTokenCache.delete(token, scope) + + # Assert + assert result is True + mock_redis.delete.assert_called_once_with(cache_key) + mock_remove_index.assert_called_once_with("tenant-1", cache_key) + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): + """Test tenant invalidation deletes indexed cache entries and index key.""" + # Arrange + tenant_id = "tenant-1" + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + mock_redis.smembers.return_value = { + b"api_token:app:token-1", + b"api_token:any:token-2", + } + + # Act + result = ApiTokenCache.invalidate_by_tenant(tenant_id) + + # Assert + assert result is True + mock_redis.smembers.assert_called_once_with(index_key) + mock_redis.delete.assert_any_call("api_token:app:token-1") + mock_redis.delete.assert_any_call("api_token:any:token-2") + mock_redis.delete.assert_any_call(index_key) + + +class TestApiTokenCacheCoreBranches: + def 
test_cached_api_token_repr_should_include_id_and_type(self): + """Test CachedApiToken __repr__ includes key identity fields.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + + assert repr(token) == "" + + def test_serialize_token_should_handle_cached_api_token_instances(self): + """Test serialization path when input is already a CachedApiToken.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + + serialized = ApiTokenCache._serialize_token(token) + + assert isinstance(serialized, bytes) + assert b'"id":"id-123"' in serialized + assert b'"token":"token-123"' in serialized + + def test_deserialize_token_should_return_none_for_null_markers(self): + """Test null cache marker deserializes to None.""" + assert ApiTokenCache._deserialize_token("null") is None + assert ApiTokenCache._deserialize_token(b"null") is None + + def test_deserialize_token_should_return_none_for_invalid_payload(self): + """Test invalid serialized payload returns None.""" + assert ApiTokenCache._deserialize_token("not-json") is None + + @patch("services.api_token_service.redis_client") + def test_get_should_return_none_on_cache_miss(self, mock_redis): + """Test cache miss branch in ApiTokenCache.get.""" + mock_redis.get.return_value = None + + result = ApiTokenCache.get("token-123", "app") + + assert result is None + mock_redis.get.assert_called_once_with("api_token:app:token-123") + + @patch("services.api_token_service.redis_client") + def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): + """Test cache hit branch in ApiTokenCache.get.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + mock_redis.get.return_value = 
token.model_dump_json().encode("utf-8") + + result = ApiTokenCache.get("token-123", "app") + + assert isinstance(result, CachedApiToken) + assert result.id == "id-123" + + @patch("services.api_token_service.redis_client") + def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): + """Test tenant index update exits early for missing tenant id.""" + ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") + + mock_redis.sadd.assert_not_called() + mock_redis.expire.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): + """Test tenant index update handles Redis write errors gracefully.""" + mock_redis.sadd.side_effect = Exception("redis down") + + ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") + + mock_redis.sadd.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): + """Test tenant index removal exits early for missing tenant id.""" + ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") + + mock_redis.srem.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): + """Test tenant index removal handles Redis errors gracefully.""" + mock_redis.srem.side_effect = Exception("redis down") + + ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") + + mock_redis.srem.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): + """Test set returns False when Redis setex fails.""" + mock_redis.setex.side_effect = Exception("redis write failed") + api_token = MagicMock() + api_token.id = "id-123" + api_token.app_id = "app-123" + api_token.tenant_id = 
"tenant-123" + api_token.type = "app" + api_token.token = "token-123" + api_token.last_used_at = None + api_token.created_at = None + + result = ApiTokenCache.set("token-123", "app", api_token) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): + """Test delete(scope=None) returns False when scan_iter raises.""" + mock_redis.scan_iter.side_effect = Exception("scan failed") + + result = ApiTokenCache.delete("token-123", None) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): + """Test scoped delete still succeeds when tenant lookup from cache fails.""" + token = "token-123" + scope = "app" + cache_key = ApiTokenCache._make_cache_key(token, scope) + mock_redis.get.side_effect = Exception("get failed") + + result = ApiTokenCache.delete(token, scope) + + assert result is True + mock_redis.delete.assert_called_once_with(cache_key) + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): + """Test scoped delete returns False when delete operation fails.""" + token = "token-123" + scope = "app" + mock_redis.get.return_value = None + mock_redis.delete.side_effect = Exception("delete failed") + + result = ApiTokenCache.delete(token, scope) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): + """Test tenant invalidation returns True when tenant index is empty.""" + mock_redis.smembers.return_value = set() + + result = ApiTokenCache.invalidate_by_tenant("tenant-123") + + assert result is True + mock_redis.delete.assert_not_called() + + @patch("services.api_token_service.redis_client") + def 
test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): + """Test tenant invalidation returns False when Redis operation fails.""" + mock_redis.smembers.side_effect = Exception("redis failed") + + result = ApiTokenCache.invalidate_by_tenant("tenant-123") + + assert result is False diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py new file mode 100644 index 0000000000..33d26f4bcb --- /dev/null +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -0,0 +1,913 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import yaml + +from dify_graph.enums import NodeType +from models import Account, AppMode +from models.model import IconType +from services import app_dsl_service +from services.app_dsl_service import ( + AppDslService, + CheckDependenciesPendingData, + ImportMode, + ImportStatus, + PendingData, + _check_version_compatibility, +) + + +class _FakeHttpResponse: + def __init__(self, content: bytes, *, raises: Exception | None = None): + self.content = content + self._raises = raises + + def raise_for_status(self) -> None: + if self._raises is not None: + raise self._raises + + +def _account_mock(*, tenant_id: str = "tenant-1", account_id: str = "account-1") -> MagicMock: + account = MagicMock(spec=Account) + account.current_tenant_id = tenant_id + account.id = account_id + return account + + +def _yaml_dump(data: dict) -> str: + return yaml.safe_dump(data, allow_unicode=True) + + +def _workflow_yaml(*, version: str = app_dsl_service.CURRENT_DSL_VERSION) -> str: + return _yaml_dump( + { + "version": version, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + } + ) + + +def test_check_version_compatibility_invalid_version_returns_failed(): + assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED + + 
+def test_check_version_compatibility_newer_version_returns_pending(): + assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING + + +def test_check_version_compatibility_major_older_returns_pending(monkeypatch): + monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") + assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING + + +def test_check_version_compatibility_minor_older_returns_completed_with_warnings(): + assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS + + +def test_check_version_compatibility_equal_returns_completed(): + assert _check_version_compatibility(app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.COMPLETED + + +def test_import_app_invalid_import_mode_raises_value_error(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_app(account=_account_mock(), import_mode="invalid-mode", yaml_content="version: '0.1.0'") + + +def test_import_app_yaml_url_requires_url(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url=None) + assert result.status == ImportStatus.FAILED + assert "yaml_url is required" in result.error + + +def test_import_app_yaml_content_requires_content(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=None) + assert result.status == ImportStatus.FAILED + assert "yaml_content is required" in result.error + + +def test_import_app_yaml_url_fetch_error_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == 
ImportStatus.FAILED + assert "Error fetching YAML from URL: boom" in result.error + + +def test_import_app_yaml_url_empty_content_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + return _FakeHttpResponse(b"") + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "Empty content" in result.error + + +def test_import_app_yaml_url_file_too_large_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + return _FakeHttpResponse(b"x" * (app_dsl_service.DSL_MAX_SIZE + 1)) + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "File size exceeds" in result.error + + +def test_import_app_yaml_not_mapping_returns_failed(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="[]") + assert result.status == ImportStatus.FAILED + assert "content must be a mapping" in result.error + + +def test_import_app_version_not_str_returns_failed(): + service = AppDslService(MagicMock()) + yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content) + assert result.status == ImportStatus.FAILED + assert "Invalid version type" in result.error + + +def test_import_app_missing_app_data_returns_failed(): + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + 
yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}), + ) + assert result.status == ImportStatus.FAILED + assert "Missing app data" in result.error + + +def test_import_app_app_id_not_found_returns_failed(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + session = MagicMock() + session.scalar.return_value = None + service = AppDslService(session) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id="missing-app", + ) + assert result.status == ImportStatus.FAILED + assert result.error == "App not found" + + +def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + existing_app = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value) + + session = MagicMock() + session.scalar.return_value = existing_app + service = AppDslService(session) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id="app-1", + ) + assert result.status == ImportStatus.FAILED + assert "Only workflow or advanced chat apps" in result.error + + +def test_import_app_pending_stores_import_info_in_redis(): + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(version="99.0.0"), + name="n", + description="d", + icon_type="emoji", + icon="i", + icon_background="#000000", + ) + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + + app_dsl_service.redis_client.setex.assert_called_once() + call = app_dsl_service.redis_client.setex.call_args + 
redis_key = call.args[0] + assert redis_key.startswith(app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX) + + +def test_import_app_completed_uses_declared_dependencies(monkeypatch): + dependencies_payload = [{"id": "langgenius/google", "version": "1.0.0"}] + + plugin_deps = [SimpleNamespace(model_dump=lambda: dependencies_payload[0])] + monkeypatch.setattr( + app_dsl_service.PluginDependency, + "model_validate", + lambda d: plugin_deps[0], + ) + + created_app = SimpleNamespace(id="app-new", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + draft_var_service = MagicMock() + monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump( + { + "version": app_dsl_service.CURRENT_DSL_VERSION, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + "dependencies": dependencies_payload, + } + ), + ) + + assert result.status == ImportStatus.COMPLETED + assert result.app_id == "app-new" + draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-new") + + +@pytest.mark.parametrize("has_workflow", [True, False]) +def test_import_app_legacy_versions_extract_dependencies(monkeypatch, has_workflow: bool): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow_graph", + lambda *_args, **_kwargs: ["from-workflow"], + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["from-model-config"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_latest_dependencies", + lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})], + ) + + created_app 
= SimpleNamespace(id="app-legacy", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + draft_var_service = MagicMock() + monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service) + + data: dict = { + "version": "0.1.5", + "kind": "app", + "app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value}, + } + if has_workflow: + data["workflow"] = {"graph": {"nodes": []}, "features": {}} + else: + data["model_config"] = {"model": {"provider": "openai"}} + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_yaml_dump(data) + ) + assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS + draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-legacy") + + +def test_import_app_yaml_error_returns_failed(monkeypatch): + def bad_safe_load(_content: str): + raise yaml.YAMLError("bad") + + monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load) + + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="x: y") + assert result.status == ImportStatus.FAILED + assert result.error.startswith("Invalid YAML format:") + + +def test_import_app_unexpected_error_returns_failed(monkeypatch): + monkeypatch.setattr( + AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops")) + ) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_workflow_yaml() + ) + assert result.status == ImportStatus.FAILED + assert result.error == "oops" + + +def test_confirm_import_expired_returns_failed(): + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", 
account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "expired" in result.error + + +def test_confirm_import_invalid_pending_data_type_returns_failed(): + app_dsl_service.redis_client.get.return_value = 123 + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "Invalid import information" in result.error + + +def test_confirm_import_success_deletes_redis_key(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + session = MagicMock() + session.scalar.return_value = None + service = AppDslService(session) + + pending = PendingData( + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + name="name", + description="desc", + icon_type="emoji", + icon="🤖", + icon_background="#fff", + app_id=None, + ) + app_dsl_service.redis_client.get.return_value = pending.model_dump_json() + + created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.COMPLETED + assert result.app_id == "confirmed-app" + app_dsl_service.redis_client.delete.assert_called_once() + + +def test_confirm_import_exception_returns_failed(monkeypatch): + app_dsl_service.redis_client.get.return_value = "not-json" + monkeypatch.setattr( + PendingData, "model_validate_json", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad")) + ) + + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert result.error == "bad" + + +def 
test_check_dependencies_returns_empty_when_no_redis_data(): + service = AppDslService(MagicMock()) + result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + assert result.leaked_dependencies == [] + + +def test_check_dependencies_calls_analysis_service(monkeypatch): + pending = CheckDependenciesPendingData(dependencies=[], app_id="app-1").model_dump_json() + app_dsl_service.redis_client.get.return_value = pending + dep = app_dsl_service.PluginDependency.model_validate( + {"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}} + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [dep], + ) + + service = AppDslService(MagicMock()) + result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + assert len(result.leaked_dependencies) == 1 + + +def test_create_or_update_app_missing_mode_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="loss app mode"): + service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) + + +def test_create_or_update_app_existing_app_updates_fields(monkeypatch): + fixed_now = object() + monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.WORKFLOW.value, + name="old", + description="old-desc", + 
icon_type=IconType.EMOJI, + icon="old-icon", + icon_background="#111111", + updated_by=None, + updated_at=None, + app_model_config=None, + ) + service = AppDslService(MagicMock()) + updated = service._create_or_update_app( + app=app, + data={ + "app": {"mode": AppMode.WORKFLOW.value, "name": "yaml-name", "icon_type": IconType.IMAGE, "icon": "X"}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + }, + account=_account_mock(), + name="override-name", + description=None, + icon_background="#222222", + ) + assert updated is app + assert app.name == "override-name" + assert app.icon_type == IconType.IMAGE + assert app.icon == "X" + assert app.icon_background == "#222222" + assert app.updated_at is fixed_now + + +def test_create_or_update_app_new_app_requires_tenant(): + account = _account_mock() + account.current_tenant_id = None + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Current tenant is not set"): + service._create_or_update_app( + app=None, + data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + account=account, + ) + + +def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(monkeypatch): + class DummyApp(SimpleNamespace): + pass + + monkeypatch.setattr(app_dsl_service, "App", DummyApp) + + sent: list[tuple[str, object]] = [] + monkeypatch.setattr(app_dsl_service.app_was_created, "send", lambda app, account: sent.append((app.id, account.id))) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = SimpleNamespace(unique_hash="uh") + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + monkeypatch.setattr( + AppDslService, 
"decrypt_dataset_id", lambda *_args, **_kwargs: "00000000-0000-0000-0000-000000000000" + ) + + session = MagicMock() + service = AppDslService(session) + deps = [ + app_dsl_service.PluginDependency.model_validate( + {"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}} + ) + ] + data = { + "app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "workflow": { + "environment_variables": [{"x": 1}], + "conversation_variables": [{"y": 2}], + "graph": { + "nodes": [ + {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["enc-1", "enc-2"]}}, + ] + }, + "features": {}, + }, + } + + app = service._create_or_update_app(app=None, data=data, account=_account_mock(), dependencies=deps) + + assert app.tenant_id == "tenant-1" + assert sent == [(app.id, "account-1")] + app_dsl_service.redis_client.setex.assert_called() + workflow_service.sync_draft_workflow.assert_called_once() + + passed_graph = workflow_service.sync_draft_workflow.call_args.kwargs["graph"] + dataset_ids = passed_graph["nodes"][0]["data"]["dataset_ids"] + assert dataset_ids == ["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000000"] + + +def test_create_or_update_app_workflow_missing_workflow_data_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Missing workflow data"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.WORKFLOW.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.WORKFLOW.value}}, + account=_account_mock(), + ) + + +def test_create_or_update_app_chat_requires_model_config(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Missing model_config"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + 
), + data={"app": {"mode": AppMode.CHAT.value}}, + account=_account_mock(), + ) + + +def test_create_or_update_app_chat_creates_model_config_and_sends_event(monkeypatch): + class DummyModelConfig(SimpleNamespace): + def from_model_config_dict(self, _cfg: dict): + return self + + monkeypatch.setattr(app_dsl_service, "AppModelConfig", DummyModelConfig) + + sent: list[str] = [] + monkeypatch.setattr( + app_dsl_service.app_model_config_was_updated, "send", lambda app, app_model_config: sent.append(app.id) + ) + + session = MagicMock() + service = AppDslService(session) + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ) + service._create_or_update_app( + app=app, + data={"app": {"mode": AppMode.CHAT.value}, "model_config": {"model": {"provider": "openai"}}}, + account=_account_mock(), + ) + + assert app.app_model_config_id is not None + assert sent == ["app-1"] + session.add.assert_called() + + +def test_create_or_update_app_invalid_mode_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Invalid app mode"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.RAG_PIPELINE.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.RAG_PIPELINE.value}}, + account=_account_mock(), + ) + + +def test_export_dsl_delegates_by_mode(monkeypatch): + workflow_calls: list[bool] = [] + model_calls: list[bool] = [] + monkeypatch.setattr(AppDslService, "_append_workflow_export_data", lambda **_kwargs: workflow_calls.append(True)) + monkeypatch.setattr( + AppDslService, "_append_model_config_export_data", lambda *_args, **_kwargs: model_calls.append(True) + ) + + workflow_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id="tenant-1", + name="n", + icon="i", + icon_type="emoji", + 
icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + AppDslService.export_dsl(workflow_app) + assert workflow_calls == [True] + + chat_app = SimpleNamespace( + mode=AppMode.CHAT.value, + tenant_id="tenant-1", + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}), + ) + AppDslService.export_dsl(chat_app) + assert model_calls == [True] + + +def test_append_workflow_export_data_filters_and_overrides(monkeypatch): + workflow_dict = { + "graph": { + "nodes": [ + {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["d1", "d2"]}}, + {"data": {"type": NodeType.TOOL, "credential_id": "secret"}}, + { + "data": { + "type": NodeType.AGENT, + "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, + } + }, + {"data": {"type": NodeType.TRIGGER_SCHEDULE.value, "config": {"x": 1}}}, + {"data": {"type": NodeType.TRIGGER_WEBHOOK.value, "webhook_url": "x", "webhook_debug_url": "y"}}, + {"data": {"type": NodeType.TRIGGER_PLUGIN.value, "subscription_id": "s"}}, + ] + } + } + + workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict) + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + AppDslService, "encrypt_dataset_id", lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}" + ) + monkeypatch.setattr( + TriggerScheduleNode := app_dsl_service.TriggerScheduleNode, + "get_default_config", + lambda: {"config": {"default": True}}, + ) + monkeypatch.setattr(AppDslService, "_extract_dependencies_from_workflow", lambda *_args, **_kwargs: ["dep-1"]) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + 
SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]}) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + export_data: dict = {} + AppDslService._append_workflow_export_data( + export_data=export_data, + app_model=SimpleNamespace(tenant_id="tenant-1"), + include_secret=False, + workflow_id=None, + ) + + nodes = export_data["workflow"]["graph"]["nodes"] + assert nodes[0]["data"]["dataset_ids"] == ["enc:tenant-1:d1", "enc:tenant-1:d2"] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] + assert nodes[3]["data"]["config"] == {"default": True} + assert nodes[4]["data"]["webhook_url"] == "" + assert nodes[4]["data"]["webhook_debug_url"] == "" + assert nodes[5]["data"]["subscription_id"] == "" + assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}] + + +def test_append_workflow_export_data_missing_workflow_raises(monkeypatch): + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + with pytest.raises(ValueError, match="Missing draft workflow configuration"): + AppDslService._append_workflow_export_data( + export_data={}, + app_model=SimpleNamespace(tenant_id="tenant-1"), + include_secret=False, + workflow_id=None, + ) + + +def test_append_model_config_export_data_filters_credential_id(monkeypatch): + monkeypatch.setattr(AppDslService, "_extract_dependencies_from_model_config", lambda *_args, **_kwargs: ["dep-1"]) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]}) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": 
{"tools": [{"credential_id": "secret"}]}}) + app_model = SimpleNamespace(tenant_id="tenant-1", app_model_config=app_model_config) + export_data: dict = {} + + AppDslService._append_model_config_export_data(export_data, app_model) + assert export_data["model_config"]["agent_mode"]["tools"] == [{}] + assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}] + + +def test_append_model_config_export_data_requires_app_config(): + with pytest.raises(ValueError, match="Missing app configuration"): + AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None)) + + +def test_extract_dependencies_from_workflow_graph_covers_all_node_types(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + + monkeypatch.setattr(app_dsl_service.ToolNodeData, "model_validate", lambda _d: SimpleNamespace(provider_id="p1")) + monkeypatch.setattr( + app_dsl_service.LLMNodeData, "model_validate", lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1")) + ) + monkeypatch.setattr( + app_dsl_service.QuestionClassifierNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")), + ) + monkeypatch.setattr( + app_dsl_service.ParameterExtractorNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")), + ) + + def kr_validate(_d): + return SimpleNamespace( + retrieval_mode="multiple", + multiple_retrieval_config=SimpleNamespace( + reranking_mode="weighted_score", + weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")), + reranking_model=None, + ), + single_retrieval_config=None, + ) + + monkeypatch.setattr(app_dsl_service.KnowledgeRetrievalNodeData, "model_validate", kr_validate) + + 
graph = { + "nodes": [ + {"data": {"type": NodeType.TOOL}}, + {"data": {"type": NodeType.LLM}}, + {"data": {"type": NodeType.QUESTION_CLASSIFIER}}, + {"data": {"type": NodeType.PARAMETER_EXTRACTOR}}, + {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": "unknown"}}, + ] + } + + deps = AppDslService._extract_dependencies_from_workflow_graph(graph) + assert deps == ["tool:p1", "model:m1", "model:m2", "model:m3", "model:m4"] + + +def test_extract_dependencies_from_workflow_graph_handles_exceptions(monkeypatch): + monkeypatch.setattr( + app_dsl_service.ToolNodeData, "model_validate", lambda _d: (_ for _ in ()).throw(ValueError("bad")) + ) + deps = AppDslService._extract_dependencies_from_workflow_graph({"nodes": [{"data": {"type": NodeType.TOOL}}]}) + assert deps == [] + + +def test_extract_dependencies_from_model_config_parses_providers(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + + deps = AppDslService._extract_dependencies_from_model_config( + { + "model": {"provider": "p1"}, + "dataset_configs": { + "datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]} + }, + "agent_mode": {"tools": [{"provider_id": "t1"}]}, + } + ) + assert deps == ["model:p1", "model:p2", "tool:t1"] + + +def test_extract_dependencies_from_model_config_handles_exceptions(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda _p: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}}) + assert deps == [] + + +def test_get_leaked_dependencies_empty_returns_empty(): + assert 
AppDslService.get_leaked_dependencies("tenant-1", []) == [] + + +def test_get_leaked_dependencies_delegates(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [SimpleNamespace(tenant_id=tenant_id, deps=dependencies)], + ) + res = AppDslService.get_leaked_dependencies("tenant-1", [SimpleNamespace(id="x")]) + assert len(res) == 1 + + +def test_encrypt_decrypt_dataset_id_respects_config(monkeypatch): + tenant_id = "tenant-1" + dataset_uuid = "00000000-0000-0000-0000-000000000000" + + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", False) + assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid + + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) + assert encrypted != dataset_uuid + assert base64.b64decode(encrypted.encode()) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid + + +def test_decrypt_dataset_id_returns_plain_uuid_unchanged(): + value = "00000000-0000-0000-0000-000000000000" + assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id="tenant-1") == value + + +def test_decrypt_dataset_id_returns_none_on_invalid_data(monkeypatch): + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id="tenant-1") is None + + +def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(monkeypatch): + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id="tenant-1") + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id="tenant-1") is None + + 
+def test_is_valid_uuid_handles_bad_inputs(): + assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True + assert AppDslService._is_valid_uuid("nope") is False diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index 47b759bc7d..c2b430c551 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -1,14 +1,50 @@ +""" +Comprehensive unit tests for services.app_generate_service.AppGenerateService. + +Covers: + - _build_streaming_task_on_subscribe (streams / pubsub / exception / idempotency) + - generate (COMPLETION / AGENT_CHAT / CHAT / ADVANCED_CHAT / WORKFLOW / invalid mode, + streaming & blocking, billing, quota-refund-on-error, rate_limit.exit) + - _get_max_active_requests (all limit combos) + - generate_single_iteration (ADVANCED_CHAT / WORKFLOW / invalid mode) + - generate_single_loop (ADVANCED_CHAT / WORKFLOW / invalid mode) + - generate_more_like_this + - _get_workflow (debugger / non-debugger / specific id / invalid format / not found) + - get_response_generator (ended / non-ended workflow run) +""" + +import threading +import time +import uuid +from contextlib import contextmanager from unittest.mock import MagicMock -import services.app_generate_service as app_generate_service_module +import pytest + +import services.app_generate_service as ags_module +from core.app.entities.app_invoke_entities import InvokeFrom from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +# --------------------------------------------------------------------------- +# Helpers / Fakes +# --------------------------------------------------------------------------- class _DummyRateLimit: + """Minimal stand-in for RateLimit that never touches Redis.""" + + _instance_dict: dict[str, 
"_DummyRateLimit"] = {} + + def __new__(cls, client_id: str, max_active_requests: int): + # avoid singleton caching across tests + instance = object.__new__(cls) + return instance + def __init__(self, client_id: str, max_active_requests: int) -> None: self.client_id = client_id self.max_active_requests = max_active_requests + self._exited: list[str] = [] @staticmethod def gen_request_key() -> str: @@ -18,101 +54,720 @@ class _DummyRateLimit: return request_id or "dummy-request-id" def exit(self, request_id: str) -> None: - return None + self._exited.append(request_id) def generate(self, generator, request_id: str): return generator -def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch): - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) +def _make_app(mode: AppMode | str, *, max_active_requests: int = 0, is_agent: bool = False) -> MagicMock: + app = MagicMock() + app.mode = mode + app.id = "app-id" + app.tenant_id = "tenant-id" + app.max_active_requests = max_active_requests + app.is_agent = is_agent + return app - workflow = MagicMock() - workflow.id = "workflow-id" - workflow.created_by = "owner-id" - - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) - - generator_spy = mocker.patch( - "services.app_generate_service.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) - - app_model = MagicMock() - app_model.mode = AppMode.WORKFLOW - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - app_model.is_agent = False +def _make_user() -> MagicMock: user = MagicMock() user.id = "user-id" - - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args={"inputs": {"k": "v"}}, - invoke_from=MagicMock(), - streaming=False, - ) - - assert result == {"result": "ok"} - - call_kwargs = generator_spy.call_args.kwargs - 
pause_state_config = call_kwargs.get("pause_state_config") - assert pause_state_config is not None - assert pause_state_config.state_owner_user_id == "owner-id" + return user -def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch): - """ - Regression test: ADVANCED_CHAT in blocking mode should return a plain dict - (non-streaming), and must not go through the async retrieve_events path. - Keeps behavior consistent with WORKFLOW blocking branch. - """ - # Disable billing and stub RateLimit to a no-op that just passes values through - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) - - # Arrange a fake workflow and wire AppGenerateService._get_workflow to return it +def _make_workflow(*, workflow_id: str = "workflow-id", created_by: str = "owner-id") -> MagicMock: workflow = MagicMock() - workflow.id = "workflow-id" - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + workflow.id = workflow_id + workflow.created_by = created_by + return workflow - # Spy on the streaming retrieval path to ensure it's NOT called - retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events") - # Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False - generate_spy = mocker.patch( - "services.app_generate_service.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, - ) +@contextmanager +def _noop_rate_limit_context(rate_limit, request_id): + """Drop-in replacement for rate_limit_context that doesn't touch Redis.""" + yield - # Minimal app model for ADVANCED_CHAT - app_model = MagicMock() - app_model.mode = AppMode.ADVANCED_CHAT - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - app_model.is_agent = False - user = MagicMock() - user.id = "user-id" +# 
--------------------------------------------------------------------------- +# _build_streaming_task_on_subscribe +# --------------------------------------------------------------------------- +class TestBuildStreamingTaskOnSubscribe: + """Tests for AppGenerateService._build_streaming_task_on_subscribe.""" - # Must include query and inputs for AdvancedChatAppGenerator - args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}} + def test_streams_mode_starts_immediately(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + # task started immediately during build + assert called == [1] + # calling the returned callback is idempotent + cb() + assert called == [1] # not called again - # Act: call service with streaming=False (blocking mode) - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args=args, - invoke_from=MagicMock(), - streaming=False, - ) + def test_pubsub_mode_starts_on_subscribe(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) # large to prevent timer + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + assert called == [] + cb() + assert called == [1] + # second call is idempotent + cb() + assert called == [1] - # Assert: returns the dict from generate(), and did not call retrieve_events() - assert result == {"result": "ok"} - assert generate_spy.call_args.kwargs.get("streaming") is False - retrieve_spy.assert_not_called() + def test_sharded_mode_starts_on_subscribe(self, monkeypatch): + """sharded is treated like pubsub (i.e. 
not 'streams').""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + assert called == [] + cb() + assert called == [1] + + def test_pubsub_fallback_timer_fires(self, monkeypatch): + """When nobody subscribes fast enough the fallback timer fires.""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 50) # 50 ms + called = [] + _cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + time.sleep(0.2) # give the timer time to fire + assert called == [1] + + def test_exception_in_start_task_returns_false(self, monkeypatch): + """When start_task raises, _try_start returns False and next call retries.""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + call_count = 0 + + def _bad(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("boom") + + cb = AppGenerateService._build_streaming_task_on_subscribe(_bad) + # first call inside build raised, but is caught; second call via cb succeeds + assert call_count == 1 + cb() + assert call_count == 2 + + def test_concurrent_subscribe_only_starts_once(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) + call_count = 0 + + def _inc(): + nonlocal call_count + call_count += 1 + + cb = AppGenerateService._build_streaming_task_on_subscribe(_inc) + threads = [threading.Thread(target=cb) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# _get_max_active_requests +# 
--------------------------------------------------------------------------- +class TestGetMaxActiveRequests: + def test_both_zero_returns_zero(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 0 + + def test_app_limit_only(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=5) + assert AppGenerateService._get_max_active_requests(app) == 5 + + def test_config_limit_only(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 10) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 10 + + def test_both_non_zero_returns_min(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 20) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=5) + assert AppGenerateService._get_max_active_requests(app) == 5 + + def test_default_active_requests_used_when_app_has_none(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 15) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 15 + + +# --------------------------------------------------------------------------- +# generate – every AppMode branch +# --------------------------------------------------------------------------- +class TestGenerate: + 
"""Tests for AppGenerateService.generate covering each mode.""" + + @pytest.fixture(autouse=True) + def _common(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + # Prevent AppExecutionParams.new from touching real models via isinstance + mocker.patch( + "services.app_generate_service.rate_limit_context", + _noop_rate_limit_context, + ) + + # -- COMPLETION --------------------------------------------------------- + def test_completion_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"result": "ok"}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + result = AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "ok"} + gen_spy.assert_called_once() + + # -- AGENT_CHAT via mode ------------------------------------------------ + def test_agent_chat_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.generate", + return_value={"result": "agent"}, + ) + mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + result = AppGenerateService.generate( + app_model=_make_app(AppMode.AGENT_CHAT), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "agent"} + gen_spy.assert_called_once() + + # -- AGENT_CHAT via is_agent flag (non-AGENT_CHAT mode) ----------------- + def test_agent_via_is_agent_flag(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.generate", + return_value={"result": 
"agent-via-flag"}, + ) + mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + app = _make_app(AppMode.CHAT, is_agent=True) + result = AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "agent-via-flag"} + gen_spy.assert_called_once() + + # -- CHAT --------------------------------------------------------------- + def test_chat_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.ChatAppGenerator.generate", + return_value={"result": "chat"}, + ) + mocker.patch( + "services.app_generate_service.ChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + app = _make_app(AppMode.CHAT, is_agent=False) + result = AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "chat"} + gen_spy.assert_called_once() + + # -- ADVANCED_CHAT blocking --------------------------------------------- + def test_advanced_chat_blocking(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + + retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events") + gen_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.generate", + return_value={"result": "advanced-blocking"}, + ) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.ADVANCED_CHAT), + user=_make_user(), + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "advanced-blocking"} + assert 
gen_spy.call_args.kwargs.get("streaming") is False + retrieve_spy.assert_not_called() + + # -- ADVANCED_CHAT streaming -------------------------------------------- + def test_advanced_chat_streaming(self, mocker, monkeypatch): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AppExecutionParams.new", + return_value=MagicMock(workflow_run_id="wfr-1", model_dump_json=MagicMock(return_value="{}")), + ) + delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay") + # Let _build_streaming_task_on_subscribe call the real on_subscribe + # so the inner closure (line 165) actually executes. + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.ADVANCED_CHAT), + user=_make_user(), + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + # In streaming mode it should go through retrieve_events, not generate + gen_instance.retrieve_events.assert_called_once() + # The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe + delay_spy.assert_called_once() + + # -- WORKFLOW blocking -------------------------------------------------- + def test_workflow_blocking(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + gen_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.generate", + return_value={"result": "workflow-blocking"}, + ) + mocker.patch( + 
"services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.WORKFLOW), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "workflow-blocking"} + call_kwargs = gen_spy.call_args.kwargs + assert call_kwargs.get("pause_state_config") is not None + assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id" + + # -- WORKFLOW streaming ------------------------------------------------- + def test_workflow_streaming(self, mocker, monkeypatch): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AppExecutionParams.new", + return_value=MagicMock(workflow_run_id="wfr-2", model_dump_json=MagicMock(return_value="{}")), + ) + delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay") + # Let _build_streaming_task_on_subscribe invoke the real on_subscribe + # so the inner closure (line 216) actually executes. 
+ monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + retrieve_spy = mocker.patch( + "services.app_generate_service.MessageBasedAppGenerator.retrieve_events", + return_value=iter([]), + ) + mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.WORKFLOW), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + retrieve_spy.assert_called_once() + # The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe + delay_spy.assert_called_once() + + # -- Invalid mode ------------------------------------------------------- + def test_invalid_mode_raises(self, mocker): + app = _make_app("invalid-mode", is_agent=False) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + +# --------------------------------------------------------------------------- +# generate – billing / quota +# --------------------------------------------------------------------------- +class TestGenerateBilling: + @pytest.fixture(autouse=True) + def _common(self, mocker, monkeypatch): + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + mocker.patch( + "services.app_generate_service.rate_limit_context", + _noop_rate_limit_context, + ) + + def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + quota_charge = MagicMock() + consume_mock = mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + return_value=quota_charge, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"ok": True}, + ) + mocker.patch( + 
"services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + consume_mock.assert_called_once_with("tenant-id") + + def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): + from services.errors.app import QuotaExceededError + from services.errors.llm import InvokeRateLimitError + + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), + ) + + with pytest.raises(InvokeRateLimitError): + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_exception_refunds_quota_and_exits_rate_limit(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + quota_charge = MagicMock() + mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + return_value=quota_charge, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + side_effect=RuntimeError("boom"), + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + with pytest.raises(RuntimeError, match="boom"): + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + quota_charge.refund.assert_called_once() + + def test_rate_limit_exit_called_in_finally_for_blocking(self, mocker, monkeypatch): + """For non-streaming (blocking) calls, rate_limit.exit should be called in 
finally.""" + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + + exit_calls: list[str] = [] + + class _TrackingRateLimit(_DummyRateLimit): + def exit(self, request_id: str) -> None: + exit_calls.append(request_id) + + mocker.patch("services.app_generate_service.RateLimit", _TrackingRateLimit) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"ok": True}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + # exit is called in finally block for non-streaming + assert len(exit_calls) >= 1 + + +# --------------------------------------------------------------------------- +# _get_workflow +# --------------------------------------------------------------------------- +class TestGetWorkflow: + def test_debugger_fetches_draft(self, mocker): + draft_wf = _make_workflow() + ws = MagicMock() + ws.get_draft_workflow.return_value = draft_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER) + assert result is draft_wf + ws.get_draft_workflow.assert_called_once() + + def test_debugger_raises_when_no_draft(self, mocker): + ws = MagicMock() + ws.get_draft_workflow.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(ValueError, match="Workflow not initialized"): + AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER) + + def test_non_debugger_fetches_published(self, mocker): + pub_wf = _make_workflow() + ws = MagicMock() + ws.get_published_workflow.return_value = pub_wf + 
mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API) + assert result is pub_wf + ws.get_published_workflow.assert_called_once() + + def test_non_debugger_raises_when_no_published(self, mocker): + ws = MagicMock() + ws.get_published_workflow.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(ValueError, match="Workflow not published"): + AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API) + + def test_specific_workflow_id_valid_uuid(self, mocker): + valid_uuid = str(uuid.uuid4()) + specific_wf = _make_workflow(workflow_id=valid_uuid) + ws = MagicMock() + ws.get_published_workflow_by_id.return_value = specific_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid + ) + assert result is specific_wf + ws.get_published_workflow_by_id.assert_called_once() + + def test_specific_workflow_id_invalid_uuid(self, mocker): + ws = MagicMock() + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(WorkflowIdFormatError): + AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id="not-a-uuid" + ) + + def test_specific_workflow_id_not_found(self, mocker): + valid_uuid = str(uuid.uuid4()) + ws = MagicMock() + ws.get_published_workflow_by_id.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(WorkflowNotFoundError): + AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid + ) + + +# --------------------------------------------------------------------------- +# generate_single_iteration 
+# --------------------------------------------------------------------------- +class TestGenerateSingleIteration: + def test_advanced_chat_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + gen_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + iter_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.single_iteration_generate", + return_value={"event": "iteration"}, + ) + app = _make_app(AppMode.ADVANCED_CHAT) + result = AppGenerateService.generate_single_iteration( + app_model=app, user=_make_user(), node_id="n1", args={"k": "v"} + ) + iter_spy.assert_called_once() + assert result == {"event": "iteration"} + + def test_workflow_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + iter_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.single_iteration_generate", + return_value={"event": "wf-iteration"}, + ) + app = _make_app(AppMode.WORKFLOW) + result = AppGenerateService.generate_single_iteration( + app_model=app, user=_make_user(), node_id="n1", args={"k": "v"} + ) + iter_spy.assert_called_once() + assert result == {"event": "wf-iteration"} + + def test_invalid_mode_raises(self, mocker): + app = _make_app(AppMode.CHAT) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate_single_iteration(app_model=app, user=_make_user(), node_id="n1", args={}) + + +# --------------------------------------------------------------------------- +# generate_single_loop +# --------------------------------------------------------------------------- +class TestGenerateSingleLoop: + def test_advanced_chat_mode(self, mocker): + 
workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + loop_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.single_loop_generate", + return_value={"event": "loop"}, + ) + app = _make_app(AppMode.ADVANCED_CHAT) + result = AppGenerateService.generate_single_loop( + app_model=app, user=_make_user(), node_id="n1", args=MagicMock() + ) + loop_spy.assert_called_once() + assert result == {"event": "loop"} + + def test_workflow_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + loop_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.single_loop_generate", + return_value={"event": "wf-loop"}, + ) + app = _make_app(AppMode.WORKFLOW) + result = AppGenerateService.generate_single_loop( + app_model=app, user=_make_user(), node_id="n1", args=MagicMock() + ) + loop_spy.assert_called_once() + assert result == {"event": "wf-loop"} + + def test_invalid_mode_raises(self, mocker): + app = _make_app(AppMode.COMPLETION) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate_single_loop(app_model=app, user=_make_user(), node_id="n1", args=MagicMock()) + + +# --------------------------------------------------------------------------- +# generate_more_like_this +# --------------------------------------------------------------------------- +class TestGenerateMoreLikeThis: + def test_delegates_to_completion_generator(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate_more_like_this", + return_value={"result": "similar"}, + ) + result = 
AppGenerateService.generate_more_like_this( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + message_id="msg-1", + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + assert result == {"result": "similar"} + gen_spy.assert_called_once() + assert gen_spy.call_args.kwargs["stream"] is True + + +# --------------------------------------------------------------------------- +# get_response_generator +# --------------------------------------------------------------------------- +class TestGetResponseGenerator: + def test_non_ended_workflow_run(self, mocker): + app = _make_app(AppMode.ADVANCED_CHAT) + workflow_run = MagicMock() + workflow_run.id = "run-1" + workflow_run.status.is_ended.return_value = False + + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([{"event": "started"}]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run) + gen_instance.retrieve_events.assert_called_once() + + def test_ended_workflow_run_still_returns_generator(self, mocker): + """Even when the run is ended, the current code still returns a generator (TODO branch).""" + app = _make_app(AppMode.WORKFLOW) + workflow_run = MagicMock() + workflow_run.id = "run-2" + workflow_run.status.is_ended.return_value = True + + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run) + # current impl falls through the TODO and still creates a generator + gen_instance.retrieve_events.assert_called_once() diff --git 
a/api/tests/unit_tests/services/test_app_model_config_service.py b/api/tests/unit_tests/services/test_app_model_config_service.py new file mode 100644 index 0000000000..d4b4bf14a3 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_model_config_service.py @@ -0,0 +1,88 @@ +from unittest.mock import patch + +import pytest + +from models.model import AppMode +from services.app_model_config_service import AppModelConfigService + + +@pytest.fixture +def mock_config_managers(): + """Fixture that patches all app config manager validate methods. + + Returns a dictionary containing the mocked config_validate methods for each manager. + """ + with ( + patch("services.app_model_config_service.ChatAppConfigManager.config_validate") as mock_chat_validate, + patch("services.app_model_config_service.AgentChatAppConfigManager.config_validate") as mock_agent_validate, + patch( + "services.app_model_config_service.CompletionAppConfigManager.config_validate" + ) as mock_completion_validate, + ): + mock_chat_validate.return_value = {"manager": "chat"} + mock_agent_validate.return_value = {"manager": "agent"} + mock_completion_validate.return_value = {"manager": "completion"} + + yield { + "chat": mock_chat_validate, + "agent": mock_agent_validate, + "completion": mock_completion_validate, + } + + +class TestAppModelConfigService: + @pytest.mark.parametrize( + ("app_mode", "selected_manager"), + [ + (AppMode.CHAT, "chat"), + (AppMode.AGENT_CHAT, "agent"), + (AppMode.COMPLETION, "completion"), + ], + ) + def test_should_route_validation_to_correct_manager_based_on_app_mode( + self, app_mode, selected_manager, mock_config_managers + ): + """Test configuration validation is delegated to the expected manager for each supported app mode.""" + tenant_id = "tenant-123" + config = {"temperature": 0.5} + + mock_chat_validate = mock_config_managers["chat"] + mock_agent_validate = mock_config_managers["agent"] + mock_completion_validate = mock_config_managers["completion"] + + result = 
AppModelConfigService.validate_configuration(tenant_id=tenant_id, config=config, app_mode=app_mode) + + assert result == {"manager": selected_manager} + + if selected_manager == "chat": + mock_chat_validate.assert_called_once_with(tenant_id, config) + mock_agent_validate.assert_not_called() + mock_completion_validate.assert_not_called() + elif selected_manager == "agent": + mock_agent_validate.assert_called_once_with(tenant_id, config) + mock_chat_validate.assert_not_called() + mock_completion_validate.assert_not_called() + else: + mock_completion_validate.assert_called_once_with(tenant_id, config) + mock_chat_validate.assert_not_called() + mock_agent_validate.assert_not_called() + + def test_should_raise_value_error_when_app_mode_is_not_supported(self, mock_config_managers): + """Test unsupported app modes raise ValueError with the invalid mode in the message.""" + tenant_id = "tenant-123" + config = {"temperature": 0.5} + + mock_chat_validate = mock_config_managers["chat"] + mock_agent_validate = mock_config_managers["agent"] + mock_completion_validate = mock_config_managers["completion"] + + with pytest.raises(ValueError, match=f"Invalid app mode: {AppMode.WORKFLOW}"): + AppModelConfigService.validate_configuration( + tenant_id=tenant_id, + config=config, + app_mode=AppMode.WORKFLOW, + ) + + mock_chat_validate.assert_not_called() + mock_agent_validate.assert_not_called() + mock_completion_validate.assert_not_called() diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py new file mode 100644 index 0000000000..bff8dc92c6 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_service.py @@ -0,0 +1,609 @@ +"""Unit tests for services.app_service.""" + +import json +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock, patch + +import pytest + +from core.errors.error import ProviderTokenNotInitError +from models import Account, Tenant +from models.model import 
App, AppMode +from services.app_service import AppService + + +@pytest.fixture +def service() -> AppService: + """Provide AppService instance.""" + return AppService() + + +@pytest.fixture +def account() -> Account: + """Create account object for create_app tests.""" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + result = Account(name="Account User", email="account@example.com") + result.id = "acc-1" + result._current_tenant = tenant + return result + + +@pytest.fixture +def default_args() -> dict: + """Create default create_app args.""" + return { + "name": "Test App", + "mode": AppMode.CHAT.value, + "icon": "🤖", + "icon_background": "#FFFFFF", + } + + +@pytest.fixture +def app_template() -> dict: + """Create basic app template for create_app tests.""" + return { + AppMode.CHAT: { + "app": {}, + "model_config": { + "model": { + "provider": "provider-a", + "name": "model-a", + "mode": "chat", + "completion_params": {}, + } + }, + } + } + + +def _make_current_user() -> Account: + user = Account(name="Tester", email="tester@example.com") + user.id = "user-1" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + user._current_tenant = tenant + return user + + +class TestAppServicePagination: + """Test suite for get_paginate_apps.""" + + def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: + """Test pagination returns None when tag filter has no targets.""" + # Arrange + args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} + + with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): + # Act + result = service.get_paginate_apps("user-1", "tenant-1", args) + + # Assert + assert result is None + + def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: + """Test pagination delegates to db.paginate when filters are valid.""" + # Arrange + args = { + "mode": "workflow", + "is_created_by_me": True, + "name": "My_App%", 
+ "tag_ids": ["tag-1"], + "page": 2, + "limit": 10, + } + expected_pagination = MagicMock() + + with ( + patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), + patch("libs.helper.escape_like_pattern", return_value="escaped"), + patch("services.app_service.db") as mock_db, + ): + mock_db.paginate.return_value = expected_pagination + + # Act + result = service.get_paginate_apps("user-1", "tenant-1", args) + + # Assert + assert result is expected_pagination + mock_db.paginate.assert_called_once() + + +class TestAppServiceCreate: + """Test suite for create_app.""" + + def test_create_app_should_create_with_matching_default_model( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app uses matching default model and persists app config.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + model_instance = SimpleNamespace( + model_name="model-a", + provider="provider-a", + model_type_instance=MagicMock(), + credentials={"k": "v"}, + ) + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + patch("services.app_service.app_was_created") as mock_event, + patch("services.app_service.FeatureService.get_system_features") as mock_features, + patch("services.app_service.BillingService") as mock_billing, + patch("services.app_service.dify_config") as mock_config, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.return_value = model_instance + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_config.BILLING_ENABLED = True + + # Act 
+ result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + assert app_instance.app_model_config_id == "cfg-1" + mock_db.session.add.assert_any_call(app_instance) + mock_db.session.add.assert_any_call(app_model_config) + assert mock_db.session.flush.call_count == 2 + mock_db.session.commit.assert_called_once() + mock_event.send.assert_called_once_with(app_instance, account=account) + mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") + + def test_create_app_should_raise_when_model_schema_missing( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app raises ValueError when non-matching model has no schema.""" + # Arrange + app_instance = SimpleNamespace(id="app-1") + model_instance = SimpleNamespace( + model_name="model-b", + provider="provider-b", + model_type_instance=MagicMock(), + credentials={"k": "v"}, + ) + model_instance.model_type_instance.get_model_schema.return_value = None + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.return_value = model_instance + + # Act & Assert + with pytest.raises(ValueError, match="model schema not found"): + service.create_app("tenant-1", default_args, account) + mock_db.session.commit.assert_not_called() + + def test_create_app_should_fallback_to_default_provider_when_model_missing( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app falls back to provider/model name when no default model instance is available.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = 
SimpleNamespace(id="cfg-1") + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + patch("services.app_service.app_was_created") as mock_event, + patch("services.app_service.FeatureService.get_system_features") as mock_features, + patch("services.app_service.EnterpriseService") as mock_enterprise, + patch("services.app_service.dify_config") as mock_config, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") + manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_config.BILLING_ENABLED = False + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + mock_event.send.assert_called_once_with(app_instance, account=account) + mock_db.session.commit.assert_called_once() + mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") + + def test_create_app_should_log_and_fallback_on_unexpected_model_error( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test unexpected model manager errors are logged and fallback provider is used.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + 
patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db"), + patch("services.app_service.app_was_created"), + patch( + "services.app_service.FeatureService.get_system_features", + return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), + ), + patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), + patch("services.app_service.logger") as mock_logger, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.side_effect = RuntimeError("boom") + manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + mock_logger.exception.assert_called_once() + + +class TestAppServiceGetAndUpdate: + """Test suite for app retrieval and update methods.""" + + def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None: + """Test get_app returns original app for non-agent modes.""" + # Arrange + app = MagicMock() + app.mode = AppMode.CHAT + app.is_agent = False + + with patch("services.app_service.current_user", _make_current_user()): + # Act + result = service.get_app(app) + + # Assert + assert result is app + + def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: + """Test get_app returns app when agent mode has no model config.""" + # Arrange + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = None + + with patch("services.app_service.current_user", _make_current_user()): + # Act + result = service.get_app(app) + + # Assert + assert result is app + + def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: + """Test get_app decrypts and masks secret tool parameters.""" + # Arrange + tool = { + "provider_type": "builtin", 
+ "provider_id": "provider-1", + "tool_name": "tool-a", + "tool_parameters": {"secret": "encrypted"}, + "extra": True, + } + model_config = MagicMock() + model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} + + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = model_config + + manager = MagicMock() + manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} + manager.mask_tool_parameters.return_value = {"secret": "***"} + + with ( + patch("services.app_service.current_user", _make_current_user()), + patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), + patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), + ): + # Act + result = service.get_app(app) + + # Assert + assert result.app_model_config is model_config + assert tool["tool_parameters"] == {"secret": "***"} + assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} + + def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: + """Test get_app logs and continues when masking fails.""" + # Arrange + tool = { + "provider_type": "builtin", + "provider_id": "provider-1", + "tool_name": "tool-a", + "tool_parameters": {"secret": "encrypted"}, + "extra": True, + } + model_config = MagicMock() + model_config.agent_mode_dict = {"tools": [tool]} + + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = model_config + + with ( + patch("services.app_service.current_user", _make_current_user()), + patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), + patch("services.app_service.logger") as mock_logger, + ): + # Act + result = service.get_app(app) + + # Assert + assert result.app_model_config is model_config + mock_logger.exception.assert_called_once() + + def 
test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: + """Test update methods set fields and commit changes.""" + # Arrange + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type="emoji", + icon="a", + icon_background="#111", + enable_site=True, + enable_api=True, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": "image", + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + patch("services.app_service.naive_utc_now", return_value="now"), + ): + # Act + updated = service.update_app(app, args) + renamed = service.update_app_name(app, "rename") + iconed = service.update_app_icon(app, "icon-2", "#333") + site_same = service.update_app_site_status(app, app.enable_site) + api_same = service.update_app_api_status(app, app.enable_api) + site_changed = service.update_app_site_status(app, False) + api_changed = service.update_app_api_status(app, False) + + # Assert + assert updated is app + assert renamed is app + assert iconed is app + assert site_same is app + assert api_same is app + assert site_changed is app + assert api_changed is app + assert mock_db.session.commit.call_count >= 5 + + +class TestAppServiceDeleteAndMeta: + """Test suite for delete and metadata methods.""" + + def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: + """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" + # Arrange + app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + with ( + patch("services.app_service.db") as mock_db, + patch( + "services.app_service.FeatureService.get_system_features", + return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), + ), + 
patch("services.app_service.EnterpriseService") as mock_enterprise, + patch( + "services.app_service.dify_config", + new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), + ), + patch("services.app_service.BillingService") as mock_billing, + patch("services.app_service.remove_app_and_related_data_task") as mock_task, + ): + # Act + service.delete_app(app) + + # Assert + mock_db.session.delete.assert_called_once_with(app) + mock_db.session.commit.assert_called_once() + mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") + mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") + mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") + + def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: + """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" + # Arrange + workflow = SimpleNamespace( + graph_dict={ + "nodes": [ + { + "data": { + "type": "tool", + "provider_type": "builtin", + "provider_id": "builtin-provider", + "tool_name": "tool_builtin", + } + }, + { + "data": { + "type": "tool", + "provider_type": "api", + "provider_id": "api-provider-id", + "tool_name": "tool_api", + } + }, + ] + } + ) + app = cast( + App, + SimpleNamespace( + mode=AppMode.WORKFLOW.value, + workflow=workflow, + app_model_config=None, + tenant_id="tenant-1", + icon_type="emoji", + icon_background="#fff", + ), + ) + + provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) + + with ( + patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), + patch("services.app_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = provider + mock_db.session.query.return_value = query + + # Act + meta = service.get_app_meta(app) + + # Assert + assert 
meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") + assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} + + def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: + """Test get_app_meta falls back to default icon when API provider lookup fails.""" + # Arrange + app_model_config = SimpleNamespace( + agent_mode_dict={ + "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] + } + ) + app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) + + with ( + patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), + patch("services.app_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act + meta = service.get_app_meta(app) + + # Assert + assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} + + def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: + """Test get_app_meta returns empty metadata when workflow/model config is absent.""" + # Arrange + workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) + chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) + + # Act + workflow_meta = service.get_app_meta(workflow_app) + chat_meta = service.get_app_meta(chat_app) + + # Assert + assert workflow_meta == {"tool_icons": {}} + assert chat_meta == {"tool_icons": {}} + + +class TestAppServiceCodeLookup: + """Test suite for app code lookup methods.""" + + def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: + """Test get_app_code_by_id raises when site is missing.""" + # Arrange + with patch("services.app_service.db") as mock_db: + query = MagicMock() + 
query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id("app-1") + + def test_get_app_code_by_id_should_return_code(self) -> None: + """Test get_app_code_by_id returns site code.""" + # Arrange + site = SimpleNamespace(code="code-1") + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = site + mock_db.session.query.return_value = query + + # Act + result = AppService.get_app_code_by_id("app-1") + + # Assert + assert result == "code-1" + + def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: + """Test get_app_id_by_code raises when code does not exist.""" + # Arrange + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("missing") + + def test_get_app_id_by_code_should_return_app_id(self) -> None: + """Test get_app_id_by_code returns linked app id.""" + # Arrange + site = SimpleNamespace(app_id="app-1") + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = site + mock_db.session.query.return_value = query + + # Act + result = AppService.get_app_id_by_code("code-1") + + # Assert + assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py new file mode 100644 index 0000000000..639e091041 --- /dev/null +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -0,0 +1,507 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import 
pytest + +import services.async_workflow_service as async_workflow_service_module +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus +from services.async_workflow_service import AsyncWorkflowService +from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.workflow.entities import AsyncTriggerResponse, TriggerData +from services.workflow.queue_dispatcher import QueuePriority + + +class AsyncWorkflowServiceTestDataFactory: + """Factory helpers for async workflow service unit tests.""" + + @staticmethod + def create_trigger_data( + app_id: str = "app-123", + tenant_id: str = "tenant-123", + workflow_id: str | None = "workflow-123", + root_node_id: str = "root-node-123", + ) -> TriggerData: + """Create valid trigger data for async workflow execution tests.""" + return TriggerData( + app_id=app_id, + tenant_id=tenant_id, + workflow_id=workflow_id, + root_node_id=root_node_id, + inputs={"name": "dify"}, + files=[], + trigger_type=AppTriggerType.UNKNOWN, + trigger_from=WorkflowRunTriggeredFrom.APP_RUN, + trigger_metadata=None, + ) + + @staticmethod + def create_trigger_log_with_data(trigger_data: TriggerData, retry_count: int = 0) -> MagicMock: + """Create a mock trigger log with serialized trigger data.""" + trigger_log = MagicMock() + trigger_log.id = "trigger-log-123" + trigger_log.trigger_data = trigger_data.model_dump_json() + trigger_log.retry_count = retry_count + trigger_log.error = "previous-error" + trigger_log.status = WorkflowTriggerStatus.FAILED + trigger_log.to_dict.return_value = {"id": trigger_log.id} + return trigger_log + + +class TestAsyncWorkflowService: + @pytest.fixture + def async_workflow_trigger_mocks(self): + """Shared fixture for async workflow trigger tests. 
+ + Yields mocks for: + - repo: SQLAlchemyWorkflowTriggerLogRepository + - dispatcher_manager_class: QueueDispatcherManager class + - dispatcher: dispatcher instance + - quota_workflow: QuotaType.WORKFLOW + - get_workflow: AsyncWorkflowService._get_workflow method + - professional_task: execute_workflow_professional + - team_task: execute_workflow_team + - sandbox_task: execute_workflow_sandbox + """ + mock_repo = MagicMock() + + def _create_side_effect(new_log): + new_log.id = "trigger-log-123" + return new_log + + mock_repo.create.side_effect = _create_side_effect + + mock_dispatcher = MagicMock() + quota_workflow = MagicMock() + mock_get_workflow = MagicMock() + + mock_professional_task = MagicMock() + mock_team_task = MagicMock() + mock_sandbox_task = MagicMock() + + with ( + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + patch.object(async_workflow_service_module, "QueueDispatcherManager") as mock_dispatcher_manager_class, + patch.object(async_workflow_service_module, "WorkflowService"), + patch.object( + async_workflow_service_module.AsyncWorkflowService, + "_get_workflow", + ) as mock_get_workflow, + patch.object( + async_workflow_service_module, + "QuotaType", + new=SimpleNamespace(WORKFLOW=quota_workflow), + ), + patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, + patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, + patch.object(async_workflow_service_module, "execute_workflow_sandbox") as mock_sandbox_task, + ): + # Configure dispatcher_manager to return our mock_dispatcher + mock_dispatcher_manager_class.return_value.get_dispatcher.return_value = mock_dispatcher + + yield { + "repo": mock_repo, + "dispatcher_manager_class": mock_dispatcher_manager_class, + "dispatcher": mock_dispatcher, + "quota_workflow": quota_workflow, + "get_workflow": mock_get_workflow, + "professional_task": 
mock_professional_task, + "team_task": mock_team_task, + "sandbox_task": mock_sandbox_task, + } + + @pytest.mark.parametrize( + ("queue_name", "selected_task_attr"), + [ + (QueuePriority.PROFESSIONAL, "execute_workflow_professional"), + (QueuePriority.TEAM, "execute_workflow_team"), + (QueuePriority.SANDBOX, "execute_workflow_sandbox"), + ], + ) + def test_should_dispatch_to_matching_celery_task_when_triggering_workflow( + self, queue_name, selected_task_attr, async_workflow_trigger_mocks + ): + """Test queue-based task routing and successful async trigger response.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = queue_name + mocks["get_workflow"].return_value = workflow + + task_result = MagicMock() + task_result.id = "task-123" + mocks["professional_task"].delay.return_value = task_result + mocks["team_task"].delay.return_value = task_result + mocks["sandbox_task"].delay.return_value = task_result + + class DummyAccount: + def __init__(self, user_id: str): + self.id = user_id + + with patch.object(async_workflow_service_module, "Account", DummyAccount): + user = DummyAccount("account-123") + + # Act + result = AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data) + + # Assert + assert isinstance(result, AsyncTriggerResponse) + assert result.workflow_trigger_log_id == "trigger-log-123" + assert result.task_id == "task-123" + assert result.status == "queued" + assert result.queue == queue_name + + mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") + assert session.commit.call_count == 2 + + created_log = mocks["repo"].create.call_args[0][0] + assert created_log.status 
== WorkflowTriggerStatus.QUEUED + assert created_log.queue_name == queue_name + assert created_log.created_by_role == CreatorUserRole.ACCOUNT + assert created_log.created_by == "account-123" + assert created_log.trigger_data == trigger_data.model_dump_json() + assert created_log.inputs == json.dumps(dict(trigger_data.inputs)) + assert created_log.celery_task_id == "task-123" + + task_mocks = { + "execute_workflow_professional": mocks["professional_task"], + "execute_workflow_team": mocks["team_task"], + "execute_workflow_sandbox": mocks["sandbox_task"], + } + for task_attr, task_mock in task_mocks.items(): + if task_attr == selected_task_attr: + task_mock.delay.assert_called_once_with({"workflow_trigger_log_id": "trigger-log-123"}) + else: + task_mock.delay.assert_not_called() + + def test_should_set_end_user_role_when_triggered_by_end_user(self, async_workflow_trigger_mocks): + """Test that non-account users are tracked as END_USER in trigger logs.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = QueuePriority.SANDBOX + mocks["get_workflow"].return_value = workflow + + task_result = MagicMock(id="task-123") + mocks["sandbox_task"].delay.return_value = task_result + + user = SimpleNamespace(id="end-user-123") + + # Act + AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data) + + # Assert + created_log = mocks["repo"].create.call_args[0][0] + assert created_log.created_by_role == CreatorUserRole.END_USER + assert created_log.created_by == "end-user-123" + + def test_should_raise_workflow_not_found_when_app_does_not_exist(self): + """Test trigger failure when app lookup returns no 
result.""" + # Arrange + session = MagicMock() + session.scalar.return_value = None + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data(app_id="missing-app") + + with ( + patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository"), + patch.object(async_workflow_service_module, "QueueDispatcherManager"), + patch.object(async_workflow_service_module, "WorkflowService"), + ): + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="App not found: missing-app"): + AsyncWorkflowService.trigger_workflow_async( + session=session, + user=SimpleNamespace(id="user-123"), + trigger_data=trigger_data, + ) + + def test_should_mark_log_rate_limited_and_raise_when_quota_exceeded(self, async_workflow_trigger_mocks): + """Test quota-exceeded path updates trigger log and raises WorkflowQuotaLimitError.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM + mocks["get_workflow"].return_value = workflow + mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + feature="workflow", + tenant_id="tenant-123", + required=1, + ) + + # Act / Assert + with pytest.raises( + WorkflowQuotaLimitError, + match="Workflow execution quota limit reached for tenant tenant-123", + ): + AsyncWorkflowService.trigger_workflow_async( + session=session, + user=SimpleNamespace(id="user-123"), + trigger_data=trigger_data, + ) + + assert session.commit.call_count == 2 + updated_log = mocks["repo"].update.call_args[0][0] + assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED + assert "Quota limit reached" in updated_log.error + mocks["professional_task"].delay.assert_not_called() + 
mocks["team_task"].delay.assert_not_called() + mocks["sandbox_task"].delay.assert_not_called() + + def test_should_raise_when_reinvoke_target_log_does_not_exist(self): + """Test reinvoke_trigger error path when original trigger log is missing.""" + # Arrange + session = MagicMock() + repo = MagicMock() + repo.get_by_id.return_value = None + + with patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo): + # Act / Assert + with pytest.raises(ValueError, match="Trigger log not found: missing-log"): + AsyncWorkflowService.reinvoke_trigger( + session=session, + user=SimpleNamespace(id="user-123"), + workflow_trigger_log_id="missing-log", + ) + + def test_should_update_original_log_and_requeue_when_reinvoking(self): + """Test reinvoke flow updates original log state and triggers a new async run.""" + # Arrange + session = MagicMock() + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + trigger_log = AsyncWorkflowServiceTestDataFactory.create_trigger_log_with_data(trigger_data, retry_count=1) + repo = MagicMock() + repo.get_by_id.return_value = trigger_log + + expected_response = AsyncTriggerResponse( + workflow_trigger_log_id="new-trigger-log-456", + task_id="task-456", + status="queued", + queue=QueuePriority.TEAM, + ) + + with ( + patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo), + patch.object( + async_workflow_service_module.AsyncWorkflowService, + "trigger_workflow_async", + return_value=expected_response, + ) as mock_trigger_workflow_async, + ): + user = SimpleNamespace(id="user-123") + + # Act + response = AsyncWorkflowService.reinvoke_trigger( + session=session, + user=user, + workflow_trigger_log_id="trigger-log-123", + ) + + # Assert + assert response == expected_response + assert trigger_log.status == WorkflowTriggerStatus.RETRYING + assert trigger_log.retry_count == 2 + assert trigger_log.error is None + assert 
trigger_log.triggered_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + called_trigger_data = mock_trigger_workflow_async.call_args[0][2] + assert isinstance(called_trigger_data, TriggerData) + assert called_trigger_data.app_id == "app-123" + + @pytest.mark.parametrize( + ("repo_result", "expected"), + [ + (None, None), + (MagicMock(), {"id": "trigger-log-123"}), + ], + ) + def test_should_return_trigger_log_dict_or_none(self, repo_result, expected): + """Test get_trigger_log returns serialized log data or None.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + fake_engine = MagicMock() + mock_repo.get_by_id.return_value = repo_result + if repo_result: + repo_result.to_dict.return_value = expected + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)), + patch.object( + async_workflow_service_module, "Session", return_value=mock_session_context + ) as mock_session_class, + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_trigger_log("trigger-log-123", tenant_id="tenant-123") + + # Assert + assert result == expected + mock_session_class.assert_called_once_with(fake_engine) + mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123") + + def test_should_return_recent_logs_as_dict_list(self): + """Test get_recent_logs converts repository models into dictionaries.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + log1 = MagicMock() + log1.to_dict.return_value = {"id": "log-1"} + log2 = MagicMock() + log2.to_dict.return_value = {"id": "log-2"} + mock_repo.get_recent_logs.return_value = [log1, log2] + + mock_session_context = MagicMock() 
+ mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), + patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_recent_logs( + tenant_id="tenant-123", + app_id="app-123", + hours=12, + limit=50, + offset=10, + ) + + # Assert + assert result == [{"id": "log-1"}, {"id": "log-2"}] + mock_repo.get_recent_logs.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + hours=12, + limit=50, + offset=10, + ) + + def test_should_return_failed_logs_for_retry_as_dict_list(self): + """Test get_failed_logs_for_retry serializes repository logs into dicts.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + log = MagicMock() + log.to_dict.return_value = {"id": "failed-log-1"} + mock_repo.get_failed_for_retry.return_value = [log] + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), + patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_failed_logs_for_retry(tenant_id="tenant-123", max_retry_count=4, limit=20) + + # Assert + assert result == [{"id": "failed-log-1"}] + mock_repo.get_failed_for_retry.assert_called_once_with(tenant_id="tenant-123", max_retry_count=4, limit=20) + + +class TestAsyncWorkflowServiceGetWorkflow: + def 
test_should_return_specific_workflow_when_workflow_id_exists(self): + """Test _get_workflow returns published workflow by id when provided.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + workflow = MagicMock() + workflow_service.get_published_workflow_by_id.return_value = workflow + + # Act + result = AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-123") + + # Assert + assert result == workflow + workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123") + workflow_service.get_published_workflow.assert_not_called() + + def test_should_raise_when_specific_workflow_id_not_found(self): + """Test _get_workflow raises WorkflowNotFoundError for unknown workflow id.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + workflow_service.get_published_workflow_by_id.return_value = None + + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="Published workflow not found: workflow-404"): + AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-404") + + def test_should_return_default_published_workflow_when_workflow_id_not_provided(self): + """Test _get_workflow returns default published workflow when no id is provided.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + workflow = MagicMock() + workflow_service.get_published_workflow.return_value = workflow + + # Act + result = AsyncWorkflowService._get_workflow(workflow_service, app_model) + + # Assert + assert result == workflow + workflow_service.get_published_workflow.assert_called_once_with(app_model) + workflow_service.get_published_workflow_by_id.assert_not_called() + + def test_should_raise_when_default_published_workflow_not_found(self): + """Test _get_workflow raises WorkflowNotFoundError when app has no published workflow.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() 
+ app_model.id = "app-123" + workflow_service.get_published_workflow.return_value = None + + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="No published workflow found for app: app-123"): + AsyncWorkflowService._get_workflow(workflow_service, app_model) diff --git a/api/tests/unit_tests/services/test_attachment_service.py b/api/tests/unit_tests/services/test_attachment_service.py new file mode 100644 index 0000000000..88be20bc41 --- /dev/null +++ b/api/tests/unit_tests/services/test_attachment_service.py @@ -0,0 +1,73 @@ +import base64 +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +import services.attachment_service as attachment_service_module +from models.model import UploadFile +from services.attachment_service import AttachmentService + + +class TestAttachmentService: + def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self): + """Test that AttachmentService keeps the provided sessionmaker instance.""" + session_factory = sessionmaker() + + service = AttachmentService(session_factory=session_factory) + + assert service._session_maker is session_factory + + def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self): + """Test that AttachmentService builds a sessionmaker bound to the provided engine.""" + engine = create_engine("sqlite:///:memory:") + + service = AttachmentService(session_factory=engine) + session = service._session_maker() + try: + assert session.bind == engine + finally: + session.close() + engine.dispose() + + @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) + def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory): + """Test that invalid session_factory types are rejected.""" + with pytest.raises(AssertionError, match="must be a sessionmaker or an 
Engine."): + AttachmentService(session_factory=invalid_session_factory) + + def test_should_return_base64_encoded_blob_when_file_exists(self): + """Test that existing files are loaded from storage and returned as base64.""" + service = AttachmentService(session_factory=sessionmaker()) + upload_file = MagicMock(spec=UploadFile) + upload_file.key = "upload-file-key" + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = upload_file + service._session_maker = MagicMock(return_value=session) + + with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: + result = service.get_file_base64("file-123") + + assert result == base64.b64encode(b"binary-content").decode() + service._session_maker.assert_called_once_with(expire_on_commit=False) + session.query.assert_called_once_with(UploadFile) + mock_load.assert_called_once_with("upload-file-key") + + def test_should_raise_not_found_when_file_does_not_exist(self): + """Test that missing files raise NotFound and never call storage.""" + service = AttachmentService(session_factory=sessionmaker()) + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + service._session_maker = MagicMock(return_value=session) + + with patch.object(attachment_service_module.storage, "load_once") as mock_load: + with pytest.raises(NotFound, match="File not found"): + service.get_file_base64("missing-file") + + service._session_maker.assert_called_once_with(expire_on_commit=False) + session.query.assert_called_once_with(UploadFile) + mock_load.assert_not_called() diff --git a/api/tests/unit_tests/services/test_batch_indexing_base.py b/api/tests/unit_tests/services/test_batch_indexing_base.py new file mode 100644 index 0000000000..bd68b67d89 --- /dev/null +++ b/api/tests/unit_tests/services/test_batch_indexing_base.py @@ -0,0 +1,387 @@ +from dataclasses import asdict +from typing import Any, ClassVar, cast +from 
unittest.mock import MagicMock, patch + +import pytest + +from core.entities.document_task import DocumentTask +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy + +# --------------------------------------------------------------------------- +# Concrete subclass for testing (the base class is abstract) +# --------------------------------------------------------------------------- + + +class ConcreteBatchProxy(BatchDocumentIndexingProxy): + """Minimal concrete implementation that provides the required class-level vars.""" + + QUEUE_NAME: ClassVar[str] = "test_queue" + NORMAL_TASK_FUNC: ClassVar[Any] = MagicMock(name="NORMAL_TASK_FUNC") + PRIORITY_TASK_FUNC: ClassVar[Any] = MagicMock(name="PRIORITY_TASK_FUNC") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TENANT_ID = "tenant-abc" +DATASET_ID = "dataset-xyz" +DOC_IDS: list[str] = ["doc-1", "doc-2", "doc-3"] + + +def make_proxy(**kwargs: Any) -> ConcreteBatchProxy: + """Factory: returns a ConcreteBatchProxy with TenantIsolatedTaskQueue mocked out.""" + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue: + proxy = ConcreteBatchProxy( + tenant_id=kwargs.get("tenant_id", TENANT_ID), + dataset_id=kwargs.get("dataset_id", DATASET_ID), + document_ids=kwargs.get("document_ids", DOC_IDS), + ) + # Expose the mock queue on the proxy so tests can assert on it + proxy._tenant_isolated_task_queue = MockQueue.return_value + return proxy + + +# --------------------------------------------------------------------------- +# Test suite +# --------------------------------------------------------------------------- + + +class TestBatchDocumentIndexingProxyInit: + """Tests for __init__ of BatchDocumentIndexingProxy.""" + + def 
test_should_store_document_ids_when_initialized(self) -> None: + """Verify that document_ids are stored on the proxy instance.""" + # Arrange + doc_ids: list[str] = ["doc-a", "doc-b"] + + # Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids) + + # Assert + assert proxy._document_ids == doc_ids + + def test_should_propagate_tenant_and_dataset_to_base_when_initialized(self) -> None: + """Verify that tenant_id and dataset_id are forwarded to the parent class.""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS) + + # Assert + assert proxy._tenant_id == TENANT_ID + assert proxy._dataset_id == DATASET_ID + + def test_should_create_tenant_isolated_queue_with_correct_args_when_initialized(self) -> None: + """Verify that TenantIsolatedTaskQueue is constructed with (tenant_id, QUEUE_NAME).""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue: + ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS) + + # Assert + MockQueue.assert_called_once_with(TENANT_ID, ConcreteBatchProxy.QUEUE_NAME) + + @pytest.mark.parametrize("doc_ids", [[], ["single-doc"], ["d1", "d2", "d3", "d4"]]) + def test_should_accept_any_length_document_ids_when_initialized(self, doc_ids: list[str]) -> None: + """Verify that empty, single, and multiple document IDs are all accepted.""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids) + + # Assert + assert list(proxy._document_ids) == doc_ids + + +class TestSendToDirectQueue: + """Tests for _send_to_direct_queue.""" + + def test_should_call_task_func_delay_with_correct_args_when_sent_to_direct_queue( + self, + ) -> None: + """Verify that 
task_func.delay is called with the right kwargs.""" + # Arrange + proxy = make_proxy() + task_func = MagicMock() + + # Act + proxy._send_to_direct_queue(task_func) + + # Assert + task_func.delay.assert_called_once_with( + tenant_id=TENANT_ID, + dataset_id=DATASET_ID, + document_ids=DOC_IDS, + ) + + def test_should_not_interact_with_tenant_queue_when_sent_to_direct_queue(self) -> None: + """Direct queue path must never touch the tenant-isolated queue.""" + # Arrange + proxy = make_proxy() + task_func = MagicMock() + + # Act + proxy._send_to_direct_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.push_tasks.assert_not_called() + mock_queue.set_task_waiting_time.assert_not_called() + + def test_should_forward_any_callable_when_sent_to_direct_queue(self) -> None: + """Verify that different task functions are each called correctly.""" + # Arrange + proxy = make_proxy() + task_a, task_b = MagicMock(), MagicMock() + + # Act + proxy._send_to_direct_queue(task_a) + proxy._send_to_direct_queue(task_b) + + # Assert + task_a.delay.assert_called_once() + task_b.delay.assert_called_once() + + +class TestSendToTenantQueue: + """Tests for _send_to_tenant_queue — both branches.""" + + # ------------------------------------------------------------------ + # Branch 1: get_task_key() is truthy → push to waiting queue + # ------------------------------------------------------------------ + + def test_should_push_task_to_queue_when_task_key_exists(self) -> None: + """When get_task_key() is truthy, tasks must be pushed via push_tasks().""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + expected_payload = [asdict(DocumentTask(tenant_id=TENANT_ID, dataset_id=DATASET_ID, document_ids=DOC_IDS))] + 
mock_queue.push_tasks.assert_called_once_with(expected_payload) + + def test_should_not_call_task_func_delay_when_task_key_exists(self) -> None: + """When a key already exists, task_func.delay must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + cast(MagicMock, task_func.delay).assert_not_called() + + def test_should_not_set_waiting_time_when_task_key_exists(self) -> None: + """When a key already exists, set_task_waiting_time must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_not_called() + + def test_should_serialize_document_task_correctly_when_pushing_to_queue(self) -> None: + """Verify the serialised payload matches asdict(DocumentTask(...)).""" + # Arrange + proxy = make_proxy(document_ids=["doc-x"]) + proxy._tenant_isolated_task_queue.get_task_key.return_value = "k" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert — inspect the payload passed to push_tasks + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + call_args = mock_queue.push_tasks.call_args + pushed_list = call_args[0][0] # first positional arg + assert len(pushed_list) == 1 + assert pushed_list[0]["tenant_id"] == TENANT_ID + assert pushed_list[0]["dataset_id"] == DATASET_ID + assert pushed_list[0]["document_ids"] == ["doc-x"] + + # ------------------------------------------------------------------ + # Branch 2: get_task_key() is falsy → set flag + dispatch via delay + # ------------------------------------------------------------------ + + def 
test_should_set_waiting_time_and_call_delay_when_no_task_key(self) -> None: + """When get_task_key() is falsy, set_task_waiting_time and task_func.delay are invoked.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_called_once() + cast(MagicMock, task_func.delay).assert_called_once_with( + tenant_id=TENANT_ID, + dataset_id=DATASET_ID, + document_ids=DOC_IDS, + ) + + def test_should_not_push_tasks_when_no_task_key(self) -> None: + """When get_task_key() is falsy, push_tasks must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.push_tasks.assert_not_called() + + @pytest.mark.parametrize("falsy_key", [None, "", 0, False]) + def test_should_init_task_when_key_is_any_falsy_value(self, falsy_key: Any) -> None: + """Verify that any falsy return from get_task_key() triggers the init branch.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = falsy_key + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_called_once() + cast(MagicMock, task_func.delay).assert_called_once() + + +class TestDispatchRouting: + """Tests for the _dispatch / delay routing logic inherited from the base class.""" + + def _mock_features(self, enabled: bool, plan: CloudPlan) -> MagicMock: + features = MagicMock() + features.billing.enabled = enabled + features.billing.subscription.plan = plan + return features + + def 
test_should_send_to_normal_tenant_queue_when_billing_enabled_and_sandbox_plan(self) -> None: + """Sandbox plan routes to normal priority queue with tenant isolation.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.SANDBOX) + + # Act + with patch.object(proxy, "_send_to_default_tenant_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_send_to_priority_tenant_queue_when_billing_enabled_and_paid_plan(self) -> None: + """Non-sandbox paid plan routes to priority queue with tenant isolation.""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.PROFESSIONAL) + + # Act + with patch.object(proxy, "_send_to_priority_tenant_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_send_to_priority_direct_queue_when_billing_not_enabled(self) -> None: + """Self-hosted / no billing → priority direct queue (no tenant isolation).""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX) + + # Act + with patch.object(proxy, "_send_to_priority_direct_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_call_dispatch_when_delay_is_invoked(self) -> None: + """Calling delay() must invoke _dispatch() exactly once.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_dispatch") as mock_dispatch: + proxy.delay() + + # Assert 
+ mock_dispatch.assert_called_once() + + def test_should_use_feature_service_for_billing_info(self) -> None: + """Verify that FeatureService.get_features is consulted during dispatch.""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX) + with patch.object(proxy, "_send_to_priority_direct_queue"): + # Act + proxy._dispatch() + + # Assert + mock_features.assert_called_once_with(TENANT_ID) + + +class TestBaseRouterHelpers: + """Tests for the three routing helper methods from the base class.""" + + def test_should_call_send_to_tenant_queue_with_normal_func_when_default_tenant_queue(self) -> None: + """_send_to_default_tenant_queue must forward NORMAL_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_tenant_queue") as mock_method: + proxy._send_to_default_tenant_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.NORMAL_TASK_FUNC) + + def test_should_call_send_to_tenant_queue_with_priority_func_when_priority_tenant_queue(self) -> None: + """_send_to_priority_tenant_queue must forward PRIORITY_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_tenant_queue") as mock_method: + proxy._send_to_priority_tenant_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC) + + def test_should_call_send_to_direct_queue_with_priority_func_when_priority_direct_queue(self) -> None: + """_send_to_priority_direct_queue must forward PRIORITY_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_direct_queue") as mock_method: + proxy._send_to_priority_direct_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC) diff --git 
a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index 5099362e00..3c0db51cd2 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -1,9 +1,12 @@ import datetime -from unittest.mock import Mock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session +from enums.cloud_plan import CloudPlan +from services import clear_free_plan_tenant_expired_logs as service_module from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs @@ -156,13 +159,453 @@ class TestClearFreePlanTenantExpiredLogs: # Should call delete for each table that has records assert mock_session.query.return_value.where.return_value.delete.called - def test_clear_message_related_tables_logging_output( - self, mock_session, sample_message_ids, sample_records, capsys + def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes( + self, mock_session, sample_message_ids ): - """Test that logging output is generated.""" + record = Mock() + record.id = "record-1" + record.to_dict.side_effect = Exception("Serialization error") + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = [record] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - pass + mock_storage.save.assert_not_called() + assert mock_session.query.return_value.where.return_value.delete.called + + +class _ImmediateFuture: + def __init__(self, fn, args, kwargs): + self._fn = fn + self._args = args + self._kwargs = kwargs + + def result(self): + 
return self._fn(*self._args, **self._kwargs) + + +class _ImmediateExecutor: + def __init__(self, *args, **kwargs) -> None: + self.submitted: list[tuple[object, tuple[object, ...], dict[str, object]]] = [] + + def submit(self, fn, *args, **kwargs): + self.submitted.append((fn, args, kwargs)) + return _ImmediateFuture(fn, args, kwargs) + + +def _session_wrapper_for_no_autoflush(session: Mock) -> Mock: + """ + ClearFreePlanTenantExpiredLogs.process_tenant uses: + with Session(db.engine).no_autoflush as session: + so Session(db.engine) must return an object with a no_autoflush context manager. + """ + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + + wrapper = MagicMock() + wrapper.no_autoflush = cm + return wrapper + + +def _session_wrapper_for_direct(session: Mock) -> Mock: + """ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:""" + wrapper = MagicMock() + wrapper.__enter__.return_value = session + wrapper.__exit__.return_value = None + return wrapper + + +def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -> None: + flask_app = service_module.Flask("test-app") + + monkeypatch.setattr( + service_module, + "db", + SimpleNamespace( + engine=object(), + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace( + all=lambda: [SimpleNamespace(id="app-1"), SimpleNamespace(id="app-2")] + ) + ), + ), + ) + + mock_storage = MagicMock() + monkeypatch.setattr(service_module, "storage", mock_storage) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + + clear_related = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", clear_related) + + # Session sequence for messages, conversations, workflow_app_logs loops: + # - messages: one batch then empty + # - conversations: one batch then empty + # - workflow app logs: 
one batch then empty + msg1 = SimpleNamespace(id="m1", to_dict=lambda: {"id": "m1"}) + conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"}) + log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"}) + + def make_query_with_batches(batches: list[list[object]]): + q = MagicMock() + q.where.return_value = q + q.limit.return_value = q + q.all.side_effect = batches + q.delete.return_value = 1 + return q + + msg_session_1 = MagicMock() + msg_session_1.query.side_effect = ( + lambda model: make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() + ) + msg_session_1.commit.return_value = None + + msg_session_2 = MagicMock() + msg_session_2.query.side_effect = ( + lambda model: make_query_with_batches([[]]) if model == service_module.Message else MagicMock() + ) + msg_session_2.commit.return_value = None + + conv_session_1 = MagicMock() + conv_session_1.query.side_effect = ( + lambda model: make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() + ) + conv_session_1.commit.return_value = None + + conv_session_2 = MagicMock() + conv_session_2.query.side_effect = ( + lambda model: make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() + ) + conv_session_2.commit.return_value = None + + wal_session_1 = MagicMock() + wal_session_1.query.side_effect = ( + lambda model: make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() + ) + wal_session_1.commit.return_value = None + + wal_session_2 = MagicMock() + wal_session_2.query.side_effect = ( + lambda model: make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() + ) + wal_session_2.commit.return_value = None + + session_wrappers = [ + _session_wrapper_for_no_autoflush(msg_session_1), + _session_wrapper_for_no_autoflush(msg_session_2), + _session_wrapper_for_no_autoflush(conv_session_1), + _session_wrapper_for_no_autoflush(conv_session_2), + 
_session_wrapper_for_no_autoflush(wal_session_1), + _session_wrapper_for_no_autoflush(wal_session_2), + ] + + monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_select(*_args, **_kwargs): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(service_module, "select", fake_select) + + # Repositories for workflow node executions and workflow runs + node_repo = MagicMock() + node_repo.get_expired_executions_batch.side_effect = [[SimpleNamespace(id="ne-1")], []] + node_repo.delete_executions_by_ids.return_value = 1 + + run_repo = MagicMock() + run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []] + run_repo.delete_runs_by_ids.return_value = 1 + + monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_node_execution_repository", + lambda _sm: node_repo, + ) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda _sm: run_repo, + ) + + ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=10) + + # messages backup, conversations backup, node executions backup, runs backup, workflow app logs backup + assert mock_storage.save.call_count >= 5 + clear_related.assert_called() + + +def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + + # Total tenant count query + count_session = MagicMock() + count_query = MagicMock() + count_query.count.return_value = 2 + count_session.query.return_value = count_query + + monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) + + # Avoid LocalProxy usage + flask_app = service_module.Flask("test-app") + 
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + echo_mock = MagicMock() + monkeypatch.setattr(service_module.click, "echo", echo_mock) + + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", True) + + def fake_get_info(tenant_id: str): + if tenant_id == "t_sandbox": + return {"subscription": {"plan": CloudPlan.SANDBOX}} + if tenant_id == "t_fail": + raise RuntimeError("boom") + return {"subscription": {"plan": "team"}} + + monkeypatch.setattr(service_module.BillingService, "get_info", staticmethod(fake_get_info)) + + process_tenant_mock = MagicMock(side_effect=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("err"))) + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + logger_exc = MagicMock() + monkeypatch.setattr(service_module.logger, "exception", logger_exc) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=["t_sandbox", "t_paid", "t_fail"]) + + # Only sandbox tenant should attempt processing, and its failure should be swallowed + logged. 
+ assert process_tenant_mock.call_count == 1 + assert logger_exc.call_count >= 1 + + +def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) + fixed_now = started_at + datetime.timedelta(hours=2) + + class FixedDateTime(datetime.datetime): + @classmethod + def now(cls, tz=None): + return fixed_now + + monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime) + + # Avoid LocalProxy usage + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + + # Sessions used: + # 1) total tenant count + # 2) per-batch tenant scan (count + tenant list) + total_session = MagicMock() + total_query = MagicMock() + total_query.count.return_value = 250 + total_session.query.return_value = total_query + + batch_session = MagicMock() + q1 = MagicMock() + q1.where.return_value = q1 + q1.count.return_value = 200 + q2 = MagicMock() + q2.where.return_value = q2 + q2.count.return_value = 200 + q3 = MagicMock() + q3.where.return_value = q3 + q3.count.return_value = 200 + q4 = MagicMock() + q4.where.return_value = q4 + q4.count.return_value = 50 # choose this interval, then scale it + + rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")] + q_rs = MagicMock() + q_rs.where.return_value = q_rs + q_rs.order_by.return_value = rows + + batch_session.query.side_effect = [q1, q2, q3, q4, q_rs] + + sessions = 
[_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] + monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) + + process_tenant_mock = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) + + # Should submit/process tenants from the batch query + assert process_tenant_mock.call_count == 2 + + +def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + + count_session = MagicMock() + count_query = MagicMock() + count_query.count.return_value = 100 + count_session.query.return_value = count_query + monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) + + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + echo_mock = MagicMock() + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", echo_mock) + + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", MagicMock()) + + tenant_ids = [f"t{i}" for i in range(100)] + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=tenant_ids) + + assert any("Processed 100 tenants" in str(call.args[0]) for call in echo_mock.call_args_list) + + +def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + 
monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) + # Keep the total range smaller than the minimum interval (1 hour) so the loop runs once. + fixed_now = started_at + datetime.timedelta(minutes=30) + + class FixedDateTime(datetime.datetime): + @classmethod + def now(cls, tz=None): + return fixed_now + + monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime) + + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + + total_session = MagicMock() + total_query = MagicMock() + total_query.count.return_value = 250 + total_session.query.return_value = total_query + + batch_session = MagicMock() + # Count results for all 5 intervals, all > 100 => take the for-else path. 
+ count_queries = [] + for _ in range(5): + q = MagicMock() + q.where.return_value = q + q.count.return_value = 200 + count_queries.append(q) + + rows = [SimpleNamespace(id="tenant-a")] + q_rs = MagicMock() + q_rs.where.return_value = q_rs + q_rs.order_by.return_value = rows + + batch_session.query.side_effect = [*count_queries, q_rs] + + sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] + monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) + + process_tenant_mock = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) + + assert process_tenant_mock.call_count == 1 + assert len(count_queries) == 5 + assert batch_session.query.call_count >= 6 + + +def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None: + flask_app = service_module.Flask("test-app") + + monkeypatch.setattr( + service_module, + "db", + SimpleNamespace( + engine=object(), + session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="app-1")])), + ), + ) + mock_storage = MagicMock() + monkeypatch.setattr(service_module, "storage", mock_storage) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", MagicMock()) + + # Make message/conversation/workflow_app_log loops no-op (empty immediately) + empty_session = MagicMock() + q_empty = MagicMock() + q_empty.where.return_value = q_empty + q_empty.limit.return_value = q_empty + q_empty.all.return_value = [] + empty_session.query.return_value = q_empty + empty_session.commit.return_value = None + session_wrappers = [ + _session_wrapper_for_no_autoflush(empty_session), + 
_session_wrapper_for_no_autoflush(empty_session), + _session_wrapper_for_no_autoflush(empty_session), + ] + monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_select(*_args, **_kwargs): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(service_module, "select", fake_select) + + # Repos: first returns exactly batch items -> no "< batch" break, second returns [] -> hit the len==0 break. + node_repo = MagicMock() + node_repo.get_expired_executions_batch.side_effect = [ + [SimpleNamespace(id="ne-1"), SimpleNamespace(id="ne-2")], + [], + ] + node_repo.delete_executions_by_ids.return_value = 2 + + run_repo = MagicMock() + run_repo.get_expired_runs_batch.side_effect = [ + [ + SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"}), + SimpleNamespace(id="wr-2", to_dict=lambda: {"id": "wr-2"}), + ], + [], + ] + run_repo.delete_runs_by_ids.return_value = 2 + + monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_node_execution_repository", + lambda _sm: node_repo, + ) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda _sm: run_repo, + ) + + ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=2) + + assert node_repo.get_expired_executions_batch.call_count == 2 + assert run_repo.get_expired_runs_batch.call_count == 2 diff --git a/api/tests/unit_tests/services/test_code_based_extension_service.py b/api/tests/unit_tests/services/test_code_based_extension_service.py new file mode 100644 index 0000000000..f6538a140a --- /dev/null +++ b/api/tests/unit_tests/services/test_code_based_extension_service.py @@ -0,0 +1,89 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from services.code_based_extension_service import 
CodeBasedExtensionService + + +class TestCodeBasedExtensionService: + def test_should_return_only_non_builtin_extensions_with_public_fields(self, monkeypatch: pytest.MonkeyPatch): + """Test service returns only non-builtin extensions with name/label/form_schema fields.""" + moderation_extension = SimpleNamespace( + name="custom-moderation", + label={"en-US": "Custom Moderation"}, + form_schema=[{"variable": "api_key"}], + builtin=False, + extension_class=object, + position=20, + ) + builtin_extension = SimpleNamespace( + name="builtin-moderation", + label={"en-US": "Builtin Moderation"}, + form_schema=[{"variable": "token"}], + builtin=True, + extension_class=object, + position=1, + ) + retrieval_extension = SimpleNamespace( + name="custom-retrieval", + label={"en-US": "Custom Retrieval"}, + form_schema=None, + builtin=False, + extension_class=object, + position=30, + ) + module_extensions_mock = MagicMock(return_value=[moderation_extension, builtin_extension, retrieval_extension]) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + result = CodeBasedExtensionService.get_code_based_extension("external_data_tool") + + assert result == [ + { + "name": "custom-moderation", + "label": {"en-US": "Custom Moderation"}, + "form_schema": [{"variable": "api_key"}], + }, + { + "name": "custom-retrieval", + "label": {"en-US": "Custom Retrieval"}, + "form_schema": None, + }, + ] + assert set(result[0].keys()) == {"name", "label", "form_schema"} + module_extensions_mock.assert_called_once_with("external_data_tool") + + def test_should_return_empty_list_when_all_extensions_are_builtin(self, monkeypatch: pytest.MonkeyPatch): + """Test builtin extensions are filtered out completely.""" + builtin_extension = SimpleNamespace( + name="builtin-moderation", + label={"en-US": "Builtin Moderation"}, + form_schema=[{"variable": "token"}], + builtin=True, + ) + module_extensions_mock = 
MagicMock(return_value=[builtin_extension]) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + result = CodeBasedExtensionService.get_code_based_extension("moderation") + + assert result == [] + module_extensions_mock.assert_called_once_with("moderation") + + def test_should_propagate_error_when_module_extensions_lookup_fails(self, monkeypatch: pytest.MonkeyPatch): + """Test ValueError from extension lookup bubbles up unchanged.""" + module_extensions_mock = MagicMock(side_effect=ValueError("Extension Module invalid-module not found")) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + with pytest.raises(ValueError, match="Extension Module invalid-module not found"): + CodeBasedExtensionService.get_code_based_extension("invalid-module") + + module_extensions_mock.assert_called_once_with("invalid-module") diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index d8ecdf45fd..75551531a2 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -1,18 +1,29 @@ """ Comprehensive unit tests for ConversationService. -This file keeps non-SQL guard/unit tests. -SQL-related tests were migrated to testcontainers integration tests. +This file provides complete test coverage for all ConversationService methods. +Tests are organized by functionality and include edge cases, error handling, +and both positive and negative test scenarios. 
""" -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch +import pytest +from sqlalchemy import asc, desc + from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account -from models.model import App, Conversation, EndUser +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models import Account, ConversationVariable +from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService -from services.message_service import MessageService +from services.errors.conversation import ( + ConversationNotExistsError, + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) +from services.errors.message import MessageNotExistsError class ConversationServiceTestDataFactory: @@ -116,6 +127,84 @@ class ConversationServiceTestDataFactory: setattr(conversation, key, value) return conversation + @staticmethod + def create_message_mock( + message_id: str = "msg-123", + conversation_id: str = "conv-123", + app_id: str = "app-123", + **kwargs, + ) -> Mock: + """ + Create a mock Message object. 
+ + Args: + message_id: Unique identifier for the message + conversation_id: Associated conversation identifier + app_id: Associated app identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock Message object with specified attributes + """ + message = create_autospec(Message, instance=True) + message.id = message_id + message.conversation_id = conversation_id + message.app_id = app_id + message.query = kwargs.get("query", "Test message content") + message.created_at = kwargs.get("created_at", datetime.utcnow()) + for key, value in kwargs.items(): + setattr(message, key, value) + return message + + @staticmethod + def create_conversation_variable_mock( + variable_id: str = "var-123", + conversation_id: str = "conv-123", + app_id: str = "app-123", + **kwargs, + ) -> Mock: + """ + Create a mock ConversationVariable object. + + Args: + variable_id: Unique identifier for the variable + conversation_id: Associated conversation identifier + app_id: Associated app identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock ConversationVariable object with specified attributes + """ + variable = create_autospec(ConversationVariable, instance=True) + variable.id = variable_id + variable.conversation_id = conversation_id + variable.app_id = app_id + variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} + variable.created_at = kwargs.get("created_at", datetime.utcnow()) + variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + + # Mock to_variable method + mock_variable = Mock() + mock_variable.id = variable_id + mock_variable.name = kwargs.get("name", "test_var") + mock_variable.value_type = kwargs.get("value_type", "string") + mock_variable.value = kwargs.get("value", "test_value") + mock_variable.description = kwargs.get("description", "") + mock_variable.selector = kwargs.get("selector", {}) + mock_variable.model_dump.return_value = { + "id": variable_id, + 
"name": kwargs.get("name", "test_var"), + "value_type": kwargs.get("value_type", "string"), + "value": kwargs.get("value", "test_value"), + "description": kwargs.get("description", ""), + "selector": kwargs.get("selector", {}), + } + variable.to_variable.return_value = mock_variable + + for key, value in kwargs.items(): + setattr(variable, key, value) + return variable + class TestConversationServicePagination: """Test conversation pagination operations.""" @@ -175,99 +264,958 @@ class TestConversationServicePagination: assert result.limit == 20 -class TestConversationServiceMessageCreation: - """ - Test message creation and pagination. +class TestConversationServiceHelpers: + """Test helper methods in ConversationService.""" - Tests MessageService operations for creating and retrieving messages - within conversations. - """ - - def test_pagination_returns_empty_when_no_user(self): + def test_get_sort_params_with_descending_sort(self): """ - Test that pagination returns empty result when user is None. + Test _get_sort_params with descending sort prefix. - This ensures proper handling of unauthenticated requests. + When sort_by starts with '-', should return field name and desc function. + """ + # Act + field, direction = ConversationService._get_sort_params("-updated_at") + + # Assert + assert field == "updated_at" + assert direction == desc + + def test_get_sort_params_with_ascending_sort(self): + """ + Test _get_sort_params with ascending sort. + + When sort_by doesn't start with '-', should return field name and asc function. + """ + # Act + field, direction = ConversationService._get_sort_params("created_at") + + # Assert + assert field == "created_at" + assert direction == asc + + def test_build_filter_condition_with_descending_sort(self): + """ + Test _build_filter_condition with descending sort direction. + + Should create a less-than filter condition. 
""" # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() + mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_conversation.updated_at = datetime.utcnow() # Act - result = MessageService.pagination_by_first_id( - app_model=app_model, - user=None, - conversation_id="conv-123", - first_id=None, - limit=10, + condition = ConversationService._build_filter_condition( + sort_field="updated_at", + sort_direction=desc, + reference_conversation=mock_conversation, ) # Assert - assert result.data == [] - assert result.has_more is False + # The condition should be a comparison expression + assert condition is not None - def test_pagination_returns_empty_when_no_conversation_id(self): + def test_build_filter_condition_with_ascending_sort(self): """ - Test that pagination returns empty result when conversation_id is None. + Test _build_filter_condition with ascending sort direction. - This ensures proper handling of invalid requests. + Should create a greater-than filter condition. + """ + # Arrange + mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_conversation.created_at = datetime.utcnow() + + # Act + condition = ConversationService._build_filter_condition( + sort_field="created_at", + sort_direction=asc, + reference_conversation=mock_conversation, + ) + + # Assert + # The condition should be a comparison expression + assert condition is not None + + +class TestConversationServiceGetConversation: + """Test conversation retrieval operations.""" + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_account(self, mock_db_session): + """ + Test successful conversation retrieval with account user. + + Should return conversation when found with proper filters. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_account_id=user.id, from_source="console" + ) + + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + mock_db_session.query.assert_called_once_with(Conversation) + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_end_user(self, mock_db_session): + """ + Test successful conversation retrieval with end user. + + Should return conversation when found with proper filters for API user. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_end_user_id=user.id, from_source="api" + ) + + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + + @patch("services.conversation_service.db.session") + def test_get_conversation_not_found_raises_error(self, mock_db_session): + """ + Test that get_conversation raises error when conversation not found. + + Should raise ConversationNotExistsError when no matching conversation found. 
""" # Arrange app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() - # Act - result = MessageService.pagination_by_first_id( - app_model=app_model, - user=user, - conversation_id="", - first_id=None, - limit=10, - ) + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = None - # Assert - assert result.data == [] - assert result.has_more is False + # Act & Assert + with pytest.raises(ConversationNotExistsError): + ConversationService.get_conversation(app_model, "conv-123", user) -class TestConversationServiceSummarization: - """ - Test conversation summarization (auto-generated names). +class TestConversationServiceRename: + """Test conversation rename operations.""" - Tests the auto_generate_name functionality that creates conversation - titles based on the first message. - """ - - @patch("services.conversation_service.db.session", autospec=True) - @patch("services.conversation_service.ConversationService.get_conversation", autospec=True) - @patch("services.conversation_service.ConversationService.auto_generate_name", autospec=True) - def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_rename_with_manual_name(self, mock_get_conversation, mock_db_session): """ - Test renaming conversation with auto-generation enabled. + Test renaming conversation with manual name. - When auto_generate is True, the service should call the auto_generate_name - method to generate a new name for the conversation. + Should update conversation name and timestamp when auto_generate is False. 
""" # Arrange app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock() - conversation.name = "Auto-generated Name" - # Mock the conversation lookup to return our test conversation mock_get_conversation.return_value = conversation - # Mock the auto_generate_name method to return the conversation + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id="conv-123", + user=user, + name="New Name", + auto_generate=False, + ) + + # Assert + assert result == conversation + assert conversation.name == "New Name" + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.ConversationService.auto_generate_name") + def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): + """ + Test renaming conversation with auto-generation. + + Should call auto_generate_name when auto_generate is True. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation mock_auto_generate.return_value = conversation # Act result = ConversationService.rename( app_model=app_model, - conversation_id=conversation.id, + conversation_id="conv-123", user=user, - name="", + name=None, auto_generate=True, ) # Assert - mock_auto_generate.assert_called_once_with(app_model, conversation) assert result == conversation + mock_auto_generate.assert_called_once_with(app_model, conversation) + + +class TestConversationServiceAutoGenerateName: + """Test conversation auto-name generation operations.""" + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_success(self, mock_llm_generator, mock_db_session): + """ + Test successful auto-generation of conversation name. + + Should generate name using LLMGenerator and update conversation. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = message + + # Mock LLM generator + mock_llm_generator.generate_conversation_name.return_value = "Generated Name" + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + assert conversation.name == "Generated Name" + mock_llm_generator.generate_conversation_name.assert_called_once_with( + app_model.tenant_id, message.query, conversation.id, app_model.id + ) + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + def test_auto_generate_name_no_message_raises_error(self, mock_db_session): + """ + Test auto-generation fails when no message found. + + Should raise MessageNotExistsError when conversation has no messages. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + # Mock database query to return None + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + ConversationService.auto_generate_name(app_model, conversation) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_handles_llm_exception(self, mock_llm_generator, mock_db_session): + """ + Test auto-generation handles LLM generator exceptions gracefully. 
+ + Should continue without name when LLMGenerator fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = message + + # Mock LLM generator to raise exception + mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + # Name should remain unchanged due to exception + mock_db_session.commit.assert_called_once() + + +class TestConversationServiceDelete: + """Test conversation deletion operations.""" + + @patch("services.conversation_service.delete_conversation_related_data") + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_success(self, mock_get_conversation, mock_db_session, mock_delete_task): + """ + Test successful conversation deletion. + + Should delete conversation and schedule cleanup task. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock(name="Test App") + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Act + ConversationService.delete(app_model, "conv-123", user) + + # Assert + mock_db_session.delete.assert_called_once_with(conversation) + mock_db_session.commit.assert_called_once() + mock_delete_task.delay.assert_called_once_with(conversation.id) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session): + """ + Test deletion handles exceptions and rolls back transaction. + + Should rollback database changes when deletion fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_db_session.delete.side_effect = Exception("Database Error") + + # Act & Assert + with pytest.raises(Exception, match="Database Error"): + ConversationService.delete(app_model, "conv-123", user) + + # Assert rollback was called + mock_db_session.rollback.assert_called_once() + + +class TestConversationServiceConversationalVariable: + """Test conversational variable operations.""" + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_success(self, mock_get_conversation, mock_session_factory): + """ + Test successful retrieval of conversational variables. + + Should return paginated list of variables for conversation. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + variable1 = ConversationServiceTestDataFactory.create_conversation_variable_mock() + variable2 = ConversationServiceTestDataFactory.create_conversation_variable_mock(variable_id="var-456") + + mock_session.scalars.return_value.all.return_value = [variable1, variable2] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 2 + assert result.limit == 10 + assert result.has_more is False + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_with_last_id(self, mock_get_conversation, mock_session_factory): + """ + Test retrieval of variables with last_id pagination. + + Should filter variables created after last_id. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( + created_at=datetime.utcnow() - timedelta(hours=1) + ) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + + mock_session.scalar.return_value = last_variable + mock_session.scalars.return_value.all.return_value = [variable] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="var-123", + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + assert result.limit == 10 + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_last_id_not_found_raises_error( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that invalid last_id raises ConversationVariableNotExistsError. + + Should raise error when last_id doesn't exist. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="invalid-id", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_mysql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for MySQL databases. + + Should apply JSON extraction filter for variable names. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "mysql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_postgresql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for PostgreSQL databases. + + Should apply JSON extraction filter for variable names. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "postgresql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + +class TestConversationServiceUpdateVariable: + """Test conversation variable update operations.""" + + @patch("services.conversation_service.variable_factory") + @patch("services.conversation_service.ConversationVariableUpdater") + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_success( + self, mock_get_conversation, mock_session_factory, mock_updater_class, mock_variable_factory + ): + """ + Test successful update of conversation variable. + + Should update variable value and return updated data. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="string") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": "new_value"} + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="new_value", + ) + + # Assert + assert result["id"] == "var-123" + assert result["value"] == "new_value" + mock_updater.update.assert_called_once_with("conv-123", updated_variable) + mock_updater.flush.assert_called_once() + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_not_found_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when variable doesn't exist. + + Should raise ConversationVariableNotExistsError. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="invalid-id", + user=user, + new_value="new_value", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_type_mismatch_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when value type doesn't match expected type. + + Should raise ConversationVariableTypeMismatchError. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="number") + mock_session.scalar.return_value = existing_variable + + # Act & Assert - Try to set string value for number variable + with pytest.raises(ConversationVariableTypeMismatchError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="string_value", # Wrong type + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_integer_number_compatibility( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that integer type accepts number values. + + Should allow number values for integer type variables. 
+ """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="integer") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": 42} + + with ( + patch("services.conversation_service.variable_factory") as mock_variable_factory, + patch("services.conversation_service.ConversationVariableUpdater") as mock_updater_class, + ): + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value=42, # Number value for integer type + ) + + # Assert + assert result["value"] == 42 + mock_updater.update.assert_called_once() + + +class TestConversationServicePaginationAdvanced: + """Advanced pagination tests for ConversationService.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_last_id_not_found(self, mock_session_factory): + """ + Test pagination with invalid last_id raises error. + + Should raise LastConversationNotExistsError when last_id doesn't exist. 
+ """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act & Assert + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id="invalid-id", + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_exclude_ids(self, mock_session_factory): + """ + Test pagination with exclude_ids filter. + + Should exclude specified conversation IDs from results. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=["excluded-123"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_has_more_detection(self, mock_session_factory): + """ + Test pagination has_more detection logic. + + Should set has_more=True when there are more results beyond limit. 
+ """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + # Return exactly limit items to trigger has_more check + conversations = [ + ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=f"conv-{i}") for i in range(20) + ] + mock_session.scalars.return_value.all.return_value = conversations + mock_session.scalar.return_value = conversations[-1] + + # Mock count query to return > 0 + mock_session.scalar.return_value = 5 # Additional items exist + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is True + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_different_sort_by(self, mock_session_factory): + """ + Test pagination with different sort fields. + + Should handle various sort_by parameters correctly. 
+ """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Test different sort fields + sort_fields = ["created_at", "-updated_at", "name", "-status"] + + for sort_by in sort_fields: + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + sort_by=sort_by, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + +class TestConversationServiceEdgeCases: + """Test edge cases and error scenarios.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_with_end_user_api_source(self, mock_session_factory): + """ + Test pagination correctly handles EndUser with API source. + + Should use 'api' as from_source for EndUser instances. 
+ """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source="api", from_end_user_id="user-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + @patch("services.conversation_service.session_factory") + def test_pagination_with_account_console_source(self, mock_session_factory): + """ + Test pagination correctly handles Account with console source. + + Should use 'console' as from_source for Account instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source="console", from_account_id="account-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + def test_pagination_with_include_ids_filter(self): + """ + Test pagination with include_ids filter. + + Should only return conversations with IDs in include_ids list. 
+ """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv-123", "conv-456"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + # Verify that include_ids filter was applied + assert mock_session.scalars.called + + def test_pagination_with_empty_exclude_ids(self): + """ + Test pagination with empty exclude_ids list. + + Should handle empty exclude_ids gracefully. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=[], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is False diff --git a/api/tests/unit_tests/services/test_conversation_variable_updater.py b/api/tests/unit_tests/services/test_conversation_variable_updater.py new file mode 100644 index 0000000000..20f7caa78e --- /dev/null +++ b/api/tests/unit_tests/services/test_conversation_variable_updater.py @@ -0,0 +1,75 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from dify_graph.variables import StringVariable +from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater + + +class TestConversationVariableUpdater: + def test_should_update_conversation_variable_data_and_commit(self): 
+ """Test update persists serialized variable data when the row exists.""" + conversation_id = "conv-123" + variable = StringVariable( + id="var-123", + name="topic", + value="new value", + ) + expected_json = variable.model_dump_json() + + row = SimpleNamespace(data="old value") + session = MagicMock() + session.scalar.return_value = row + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + session_maker = MagicMock(return_value=session_context) + updater = ConversationVariableUpdater(session_maker) + + updater.update(conversation_id=conversation_id, variable=variable) + + session_maker.assert_called_once_with() + session.scalar.assert_called_once() + stmt = session.scalar.call_args.args[0] + compiled_params = stmt.compile().params + assert variable.id in compiled_params.values() + assert conversation_id in compiled_params.values() + assert row.data == expected_json + session.commit.assert_called_once() + + def test_should_raise_not_found_error_when_conversation_variable_missing(self): + """Test update raises ConversationVariableNotFoundError when no matching row exists.""" + conversation_id = "conv-404" + variable = StringVariable( + id="var-404", + name="topic", + value="value", + ) + + session = MagicMock() + session.scalar.return_value = None + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + session_maker = MagicMock(return_value=session_context) + updater = ConversationVariableUpdater(session_maker) + + with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): + updater.update(conversation_id=conversation_id, variable=variable) + + session.commit.assert_not_called() + + def test_should_do_nothing_when_flush_is_called(self): + """Test flush currently behaves as a no-op and returns None.""" + session_maker = MagicMock() + updater = 
ConversationVariableUpdater(session_maker) + + result = updater.flush() + + assert result is None + session_maker.assert_not_called() diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..9ef314cb9e --- /dev/null +++ b/api/tests/unit_tests/services/test_credit_pool_service.py @@ -0,0 +1,157 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import services.credit_pool_service as credit_pool_service_module +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from services.credit_pool_service import CreditPoolService + + +@pytest.fixture +def mock_credit_deduction_setup(): + """Fixture providing common setup for credit deduction tests.""" + pool = SimpleNamespace(remaining_credits=50) + fake_engine = MagicMock() + session = MagicMock() + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool) + mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine)) + mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context) + + return { + "pool": pool, + "fake_engine": fake_engine, + "session": session, + "session_context": session_context, + "patches": (mock_get_pool, mock_db, mock_session), + } + + +class TestCreditPoolService: + def test_should_create_default_pool_with_trial_type_and_configured_quota(self): + """Test create_default_pool persists a trial pool using configured hosted credits.""" + tenant_id = "tenant-123" + hosted_pool_credits = 5000 + + with ( + patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits), + patch.object(credit_pool_service_module, "db") as mock_db, + ): + pool = 
CreditPoolService.create_default_pool(tenant_id) + + assert isinstance(pool, TenantCreditPool) + assert pool.tenant_id == tenant_id + assert pool.pool_type == "trial" + assert pool.quota_limit == hosted_pool_credits + assert pool.quota_used == 0 + mock_db.session.add.assert_called_once_with(pool) + mock_db.session.commit.assert_called_once() + + def test_should_return_first_pool_from_query_when_get_pool_called(self): + """Test get_pool queries by tenant and pool_type and returns first result.""" + tenant_id = "tenant-123" + pool_type = "enterprise" + expected_pool = MagicMock(spec=TenantCreditPool) + + with patch.object(credit_pool_service_module, "db") as mock_db: + query = mock_db.session.query.return_value + filtered_query = query.filter_by.return_value + filtered_query.first.return_value = expected_pool + + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type) + + assert result == expected_pool + mock_db.session.query.assert_called_once_with(TenantCreditPool) + query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type) + filtered_query.first.assert_called_once() + + def test_should_return_false_when_pool_not_found_in_check_credits_available(self): + """Test check_credits_available returns False when tenant has no pool.""" + with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool: + result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10) + + assert result is False + mock_get_pool.assert_called_once_with("tenant-123", "trial") + + def test_should_return_true_when_remaining_credits_cover_required_amount(self): + """Test check_credits_available returns True when remaining credits are sufficient.""" + pool = SimpleNamespace(remaining_credits=100) + + with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool: + result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) + + assert result is 
True + mock_get_pool.assert_called_once_with("tenant-123", "trial") + + def test_should_return_false_when_remaining_credits_are_insufficient(self): + """Test check_credits_available returns False when required credits exceed remaining credits.""" + pool = SimpleNamespace(remaining_credits=30) + + with patch.object(CreditPoolService, "get_pool", return_value=pool): + result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60) + + assert result is False + + def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self): + """Test check_and_deduct_credits raises when tenant credit pool does not exist.""" + with patch.object(CreditPoolService, "get_pool", return_value=None): + with pytest.raises(QuotaExceededError, match="Credit pool not found"): + CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) + + def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self): + """Test check_and_deduct_credits raises when remaining credits are zero or negative.""" + pool = SimpleNamespace(remaining_credits=0) + + with patch.object(CreditPoolService, "get_pool", return_value=pool): + with pytest.raises(QuotaExceededError, match="No credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) + + def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup): + """Test check_and_deduct_credits updates quota_used by the actual deducted amount.""" + tenant_id = "tenant-123" + pool_type = "trial" + credits_required = 200 + remaining_credits = 120 + expected_deducted_credits = 120 + + mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits + patches = mock_credit_deduction_setup["patches"] + session = mock_credit_deduction_setup["session"] + + with patches[0], patches[1], patches[2]: + result = CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + 
credits_required=credits_required, + pool_type=pool_type, + ) + + assert result == expected_deducted_credits + session.execute.assert_called_once() + session.commit.assert_called_once() + + stmt = session.execute.call_args.args[0] + compiled_params = stmt.compile().params + assert tenant_id in compiled_params.values() + assert pool_type in compiled_params.values() + assert expected_deducted_credits in compiled_params.values() + + def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup): + """Test check_and_deduct_credits translates DB update failures to QuotaExceededError.""" + mock_credit_deduction_setup["pool"].remaining_credits = 50 + mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure") + session = mock_credit_deduction_setup["session"] + + patches = mock_credit_deduction_setup["patches"] + mock_logger = patch.object(credit_pool_service_module, "logger") + + with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj: + with pytest.raises(QuotaExceededError, match="Failed to deduct credits"): + CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10) + + session.commit.assert_not_called() + mock_logger_obj.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py deleted file mode 100644 index cc718c9997..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py +++ /dev/null @@ -1,216 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset -from services.dataset_service import DatasetService - - -class DatasetDeleteTestDataFactory: - """Factory class for creating test data and mock objects for dataset delete tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = 
"dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - doc_form: str | None = None, - indexing_technique: str | None = "high_quality", - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.doc_form = doc_form - dataset.indexing_technique = indexing_technique - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.ADMIN, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - -class TestDatasetServiceDeleteDataset: - """ - Comprehensive unit tests for DatasetService.delete_dataset method. - - This test suite covers all deletion scenarios including: - - Normal dataset deletion with documents - - Empty dataset deletion (no documents, doc_form is None) - - Dataset deletion with missing indexing_technique - - Permission checks - - Event handling - - This test suite provides regression protection for issue #27073. 
- """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted, - ): - yield { - "get_dataset": mock_get_dataset, - "check_permission": mock_check_perm, - "db_session": mock_db, - "dataset_was_deleted": mock_dataset_was_deleted, - } - - def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of a dataset with documents. - - This test verifies: - - Dataset is retrieved correctly - - Permission check is performed - - dataset_was_deleted event is sent - - Dataset is deleted from database - - Method returns True - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" - ) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of an empty dataset (no documents, doc_form is None). 
- - This test verifies that: - - Empty datasets can be deleted without errors - - dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None) - - Dataset is deleted from database - - Method returns True - - This is the primary test for issue #27073 where deleting an empty dataset - caused internal server error due to assertion failure in event handlers. - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies): - """ - Test deletion of dataset with partial None values. - - This test verifies that datasets with partial None values (e.g., doc_form exists - but indexing_technique is None) can be deleted successfully. The event handler - will skip cleanup if any required field is None. - - Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions - to verify all core deletion operations are performed, not just event sending. 
- """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow (Gemini suggestion implemented) - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies): - """ - Test deletion of dataset where doc_form is None but indexing_technique exists. - - This edge case can occur in certain dataset configurations and should be handled - gracefully by the event handler's conditional check. 
- """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality") - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_not_found(self, mock_dataset_service_dependencies): - """ - Test deletion attempt when dataset doesn't exist. - - This test verifies that: - - Method returns False when dataset is not found - - No deletion operations are performed - - No events are sent - """ - # Arrange - dataset_id = "non-existent-dataset" - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = None - - # Act - result = DatasetService.delete_dataset(dataset_id, user) - - # Assert - assert result is False - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) - mock_dataset_service_dependencies["check_permission"].assert_not_called() - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called() - mock_dataset_service_dependencies["db_session"].delete.assert_not_called() - mock_dataset_service_dependencies["db_session"].commit.assert_not_called() diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py new file mode 100644 index 
0000000000..105ef7ba48 --- /dev/null +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -0,0 +1,760 @@ +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy.orm import Session + +from core.plugin.entities.plugin_daemon import CredentialType +from dify_graph.model_runtime.entities.provider_entities import FormType +from models.account import Account +from models.model import EndUser +from models.oauth import DatasourceProvider +from models.provider_ids import DatasourceProviderID +from services.datasource_provider_service import DatasourceProviderService, get_current_user + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID: + return DatasourceProviderID(s) + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + + +class TestDatasourceProviderService: + """Comprehensive tests for DatasourceProviderService targeting >95% coverage.""" + + @pytest.fixture + def service(self): + return DatasourceProviderService() + + @pytest.fixture + def mock_db_session(self): + """ + Robust, chainable query mock. + q returns itself for .filter_by(), .order_by(), .where() so any + SQLAlchemy chaining pattern works without multiple brittle sub-mocks. 
+ """ + with patch("services.datasource_provider_service.Session") as mock_cls: + sess = MagicMock(spec=Session) + + q = MagicMock() + sess.query.return_value = q + + # Self-returning chain — any method called on q returns q + q.filter_by.return_value = q + q.order_by.return_value = q + q.where.return_value = q + + # Default terminal values (tests override per-case) + q.first.return_value = None + q.all.return_value = [] + q.count.return_value = 0 + q.delete.return_value = 1 + + mock_cls.return_value.__enter__.return_value = sess + mock_cls.return_value.no_autoflush.__enter__.return_value = sess + + yield sess + + @pytest.fixture(autouse=True) + def patch_db(self, mock_db_session): + with patch("services.datasource_provider_service.db") as mock_db: + mock_db.session = mock_db_session + mock_db.engine = MagicMock() + yield mock_db + + @pytest.fixture(autouse=True) + def patch_externals(self): + with ( + patch("httpx.request") as mock_httpx, + patch("services.datasource_provider_service.dify_config") as mock_cfg, + patch("services.datasource_provider_service.encrypter") as mock_enc, + patch("services.datasource_provider_service.redis_client") as mock_redis, + patch("services.datasource_provider_service.generate_incremental_name") as mock_genname, + patch("services.datasource_provider_service.OAuthHandler") as mock_oauth, + ): + mock_cfg.CONSOLE_API_URL = "http://localhost" + mock_enc.encrypt_token.return_value = "enc_tok" + mock_enc.decrypt_token.return_value = "dec_tok" + mock_enc.decrypt.return_value = {"k": "dec"} + mock_enc.encrypt.return_value = {"k": "enc"} + mock_enc.obfuscated_token.return_value = "obf" + mock_enc.mask_plugin_credentials.return_value = {"k": "mask"} + + mock_redis.lock.return_value.__enter__.return_value = MagicMock() + mock_genname.return_value = "gen_name" + + mock_oauth.return_value.refresh_credentials.return_value = MagicMock( + credentials={"k": "v"}, expires_at=9999 + ) + + resp = MagicMock() + resp.status_code = 200 + 
resp.json.return_value = { + "code": 0, + "message": "ok", + "data": { + "provider": "prov", + "plugin_unique_identifier": "pui", + "plugin_id": "org/plug", + "is_authorized": False, + "declaration": { + "identity": { + "author": "a", + "name": "n", + "description": {"en_US": "d"}, + "icon": "i", + "label": {"en_US": "l"}, + }, + "credentials_schema": [], + "oauth_schema": {"credentials_schema": [], "client_schema": []}, + "provider_type": "local_file", + "datasources": [], + }, + }, + } + mock_httpx.return_value = resp + + # Store handles for assertions + self._enc = mock_enc + self._redis = mock_redis + yield + + @pytest.fixture + def mock_user(self): + u = MagicMock() + u.id = "uid-1" + return u + + # ----------------------------------------------------------------------- + # get_current_user (lines 27-40) + # ----------------------------------------------------------------------- + + def test_should_return_proxy_when_current_object_is_account(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + user_obj = MagicMock() + user_obj.__class__ = Account + proxy._get_current_object.return_value = user_obj + assert get_current_user() is proxy + + def test_should_return_proxy_when_current_object_is_enduser(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + user_obj = MagicMock() + user_obj.__class__ = EndUser + proxy._get_current_object.return_value = user_obj + assert get_current_user() is proxy + + def test_should_return_proxy_when_get_current_object_raises_attribute_error(self): + """AttributeError from LocalProxy falls back to the proxy itself.""" + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + proxy._get_current_object.side_effect = AttributeError("no attr") + proxy.__class__ = Account # make the proxy itself satisfy isinstance + assert get_current_user() is proxy + + def test_should_raise_type_error_when_user_is_not_account_or_enduser(self): + with 
patch("libs.login.current_user", new_callable=MagicMock) as proxy: + proxy._get_current_object.return_value = "plain_string" + with pytest.raises(TypeError, match="current_user must be Account or EndUser"): + get_current_user() + + # ----------------------------------------------------------------------- + # is_system_oauth_params_exist (line 357-363) + # ----------------------------------------------------------------------- + + def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session): + mock_db_session.query().first.return_value = MagicMock() + assert service.is_system_oauth_params_exist(make_id()) is True + + def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + assert service.is_system_oauth_params_exist(make_id()) is False + + # ----------------------------------------------------------------------- + # is_tenant_oauth_params_enabled (lines 365-379) + # NOTE: uses .count() not .first() + # ----------------------------------------------------------------------- + + def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session): + mock_db_session.query().count.return_value = 1 + assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True + + def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False + + # ----------------------------------------------------------------------- + # remove_oauth_custom_client_params (lines 55-61) + # ----------------------------------------------------------------------- + + def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session): + service.remove_oauth_custom_client_params("t1", make_id()) + mock_db_session.query().delete.assert_called_once() + + # 
----------------------------------------------------------------------- + # setup_oauth_custom_client_params (315-351) + # ----------------------------------------------------------------------- + + def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session): + """When credentials=None, should return immediately without any DB write.""" + service.setup_oauth_custom_client_params("t1", make_id(), None, None) + mock_db_session.add.assert_not_called() + + def test_should_create_new_config_when_none_exists(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True) + mock_db_session.add.assert_called_once() + + def test_should_update_existing_config_when_record_found(self, service, mock_db_session): + existing = MagicMock() + mock_db_session.query().first.return_value = existing + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False) + mock_db_session.add.assert_not_called() # update in place, no add + + # ----------------------------------------------------------------------- + # decrypt / encrypt credentials (lines 70-98) + # ----------------------------------------------------------------------- + + def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "enc_val"} + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.decrypt_datasource_provider_credentials("t1", p, "org/plug", "prov") + assert result["sk"] == "dec_tok" + + def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session): + p = 
MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.encrypt_datasource_provider_credentials("t1", "prov", "org/plug", {"sk": "plain"}, p) + assert result["sk"] == "enc_tok" + self._enc.encrypt_token.assert_called() + + # ----------------------------------------------------------------------- + # get_datasource_credentials (lines 113-165) + # ----------------------------------------------------------------------- + + def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user): + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + mock_db_session.query().first.return_value = None + assert service.get_datasource_credentials("t1", "prov", "org/plug") == {} + + def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user): + """Expired OAuth credential (expires_at near zero) triggers a silent refresh.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = 0 # expired + p.encrypted_credentials = {"tok": "x"} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "get_oauth_client", return_value={"oc": "v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}), + ): + service.get_datasource_credentials("t1", "prov", "org/plug") + mock_db_session.commit.assert_called_once() + + def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user): + """API key credentials with expires_at=-1 skip refresh and return directly.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.expires_at = -1 # sentinel: never expires + p.encrypted_credentials = {"k": "v"} + mock_db_session.query().first.return_value 
= p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}), + ): + result = service.get_datasource_credentials("t1", "prov", "org/plug") + assert result == {"k": "plain"} + + def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user): + """When credential_id is passed, the credential_id filter path (line 113) is taken.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.expires_at = -1 + p.encrypted_credentials = {} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}), + ): + result = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id") + assert result == {"k": "v"} + + # ----------------------------------------------------------------------- + # get_all_datasource_credentials_by_provider (lines 176-228) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user): + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + mock_db_session.query().all.return_value = [] + assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == [] + + def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = 0 + p.encrypted_credentials = {"t": "x"} + mock_db_session.query().all.return_value = [p] + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "get_oauth_client", 
return_value={"oc": "v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}), + ): + result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # update_datasource_provider_name (lines 236-303) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with pytest.raises(ValueError, match="not found"): + service.update_datasource_provider_name("t1", make_id(), "new", "cred-id") + + def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "same" + mock_db_session.query().first.return_value = p + service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") + mock_db_session.commit.assert_not_called() + + def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.is_default = False + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 # conflict + with pytest.raises(ValueError, match="already exists"): + service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") + + def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.is_default = False + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") + assert p.name == "new_name" + mock_db_session.commit.assert_called_once() + + # 
----------------------------------------------------------------------- + # set_default_datasource_provider (lines 277-303) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with pytest.raises(ValueError, match="not found"): + service.set_default_datasource_provider("t1", make_id(), "bad-id") + + def test_should_mark_target_as_default_and_commit(self, service, mock_db_session): + target = MagicMock(spec=DatasourceProvider) + target.provider = "provider" + target.plugin_id = "org/plug" + mock_db_session.query().first.return_value = target + service.set_default_datasource_provider("t1", make_id(), "new-id") + assert target.is_default is True + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # get_oauth_encrypter (lines 404-420) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_oauth_schema_missing(self, service): + pm = MagicMock() + pm.declaration.oauth_schema = None + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + with pytest.raises(ValueError, match="oauth schema not found"): + service.get_oauth_encrypter("t1", make_id()) + + def test_should_return_encrypter_when_oauth_schema_exists(self, service): + schema_item = MagicMock() + schema_item.to_basic_provider_config.return_value = MagicMock() + pm = MagicMock() + pm.declaration.oauth_schema.client_schema = [schema_item] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm), + patch( + "services.datasource_provider_service.create_provider_encrypter", + return_value=(MagicMock(), MagicMock()), + ), + ): + result = service.get_oauth_encrypter("t1", make_id()) + assert result is not None + + # 
----------------------------------------------------------------------- + # get_tenant_oauth_client (lines 381-402) + # ----------------------------------------------------------------------- + + def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session): + tenant_params = MagicMock() + tenant_params.client_params = {"k": "v"} + mock_db_session.query().first.return_value = tenant_params + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_tenant_oauth_client("t1", make_id(), mask=True) + assert result == {"k": "mask"} + + def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session): + tenant_params = MagicMock() + tenant_params.client_params = {"k": "v"} + mock_db_session.query().first.return_value = tenant_params + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_tenant_oauth_client("t1", make_id(), mask=False) + assert result == {"k": "dec"} + + def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + assert service.get_tenant_oauth_client("t1", make_id()) is None + + # ----------------------------------------------------------------------- + # get_oauth_client (lines 423-457) + # ----------------------------------------------------------------------- + + def test_should_use_tenant_config_when_available(self, service, mock_db_session): + mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"}) + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_oauth_client("t1", make_id()) + assert result == {"k": "dec"} + + def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session): + mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] + with 
( + patch.object(service.provider_manager, "fetch_datasource_provider"), + patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True), + ): + result = service.get_oauth_client("t1", make_id()) + assert result == {"k": "sys"} + + def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session): + """Neither tenant nor system credentials → raises ValueError.""" + mock_db_session.query().first.side_effect = [None, None] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider"), + patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False), + ): + with pytest.raises(ValueError, match="Please configure oauth client params"): + service.get_oauth_client("t1", make_id()) + + # ----------------------------------------------------------------------- + # add_datasource_oauth_provider (lines 539-607) + # ----------------------------------------------------------------------- + + def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): + """Conflict on name results in auto-incremented name, not an error.""" + mock_db_session.query().count.return_value = 1 # conflict first, then auto-named + mock_db_session.query().all.return_value = [] + with ( + patch.object(service, "extract_secret_variables", return_value=[]), + patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"), + ): + service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {}) + 
mock_db_session.add.assert_called_once() + + def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session): + """name=None causes auto-generation via generate_next_datasource_provider_name.""" + mock_db_session.query().count.return_value = 0 + mock_db_session.query().all.return_value = [] + with ( + patch.object(service, "extract_secret_variables", return_value=[]), + patch.object(service, "generate_next_datasource_provider_name", return_value="auto"), + ): + service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + + def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=["secret_key"]): + service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"}) + self._enc.encrypt_token.assert_called() + + def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {}) + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # reauthorize_datasource_oauth_provider (lines 477-537) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with patch.object(service, "extract_secret_variables", return_value=[]): + with pytest.raises(ValueError, match="not found"): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id") + + def 
test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") + mock_db_session.commit.assert_called_once() + + def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 # conflict + mock_db_session.query().all.return_value = [] + with patch.object(service, "extract_secret_variables", return_value=["tok"]): + service.reauthorize_datasource_oauth_provider( + "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" + ) + mock_db_session.commit.assert_called_once() + + def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=["tok"]): + service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id") + self._enc.encrypt_token.assert_called() + + def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # add_datasource_api_key_provider (lines 
608-675) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user): + """explicit name supplied + conflict → raises ValueError immediately.""" + mock_db_session.query().count.return_value = 1 + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="already exists"): + service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"}) + + def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")), + patch.object(service, "extract_secret_variables", return_value=[]), + ): + with pytest.raises(ValueError, match="Failed to validate"): + service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"}) + + def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials"), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + ): + service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"}) + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + 
patch.object(service.provider_manager, "validate_provider_credentials"), + patch.object(service, "extract_secret_variables", return_value=[]), + ): + service.add_datasource_api_key_provider(None, "t1", make_id(), {}) + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # extract_secret_variables (lines 666-699) + # ----------------------------------------------------------------------- + + def test_should_extract_secret_variable_names_for_api_key_schema(self, service): + schema = MagicMock() + schema.name = "my_secret" + schema.type = MagicMock() + schema.type.value = FormType.SECRET_INPUT # "secret-input" + pm = MagicMock() + pm.declaration.credentials_schema = [schema] + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY) + assert "my_secret" in result + + def test_should_extract_secret_variable_names_for_oauth2_schema(self, service): + schema = MagicMock() + schema.name = "oauth_secret" + schema.type = MagicMock() + schema.type.value = FormType.SECRET_INPUT + pm = MagicMock() + pm.declaration.oauth_schema.credentials_schema = [schema] + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.OAUTH2) + assert "oauth_secret" in result + + def test_should_raise_value_error_when_credential_type_is_invalid(self, service): + pm = MagicMock() + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + with pytest.raises(ValueError, match="Invalid credential type"): + service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED) + + # ----------------------------------------------------------------------- + # list_datasource_credentials (lines 721-754) + # 
----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session): + mock_db_session.query().all.return_value = [] + assert service.list_datasource_credentials("t1", "prov", "org/plug") == [] + + def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "v"} + p.is_default = False + mock_db_session.query().all.return_value = [p] + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.list_datasource_credentials("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # get_all_datasource_credentials (lines 808-871) + # ----------------------------------------------------------------------- + + def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service): + with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr: + ds = MagicMock() + ds.provider = "prov" + ds.plugin_id = "org/plug" + ds.declaration.identity.label.model_dump.return_value = {"en_US": "Label"} + mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds] + cred = {"credential": {"k": "v"}, "is_default": True} + with patch.object(service, "list_datasource_credentials", return_value=[cred]): + results = service.get_all_datasource_credentials("t1") + assert len(results) == 1 + + def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session): + """Lines 819-871: get_all_datasource_credentials covers hardcoded langgenius plugin IDs.""" + with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr: + ds = MagicMock() + ds.plugin_id = "langgenius/firecrawl_datasource" + ds.provider = "firecrawl" + ds.plugin_unique_identifier = 
"pui" + ds.declaration.identity.icon = "icon" + ds.declaration.identity.name = "langgenius/firecrawl_datasource" + ds.declaration.identity.label.model_dump.return_value = {"en_US": "Firecrawl"} + ds.declaration.identity.description.model_dump.return_value = {"en_US": "desc"} + ds.declaration.identity.author = "langgenius" + ds.declaration.credentials_schema = [] + ds.declaration.oauth_schema.client_schema = [] + ds.declaration.oauth_schema.credentials_schema = [] + mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds] + with ( + patch.object(service, "list_datasource_credentials", return_value=[]), + patch.object(service, "get_tenant_oauth_client", return_value=None), + patch.object(service, "is_tenant_oauth_params_enabled", return_value=False), + patch.object(service, "is_system_oauth_params_exist", return_value=False), + ): + results = service.get_all_datasource_credentials("t1") + assert len(results) == 1 + assert results[0]["oauth_schema"] is not None + + # ----------------------------------------------------------------------- + # get_real_datasource_credentials (lines 873-915) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session): + mock_db_session.query().all.return_value = [] + assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == [] + + def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "v"} + mock_db_session.query().all.return_value = [p] + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.get_real_datasource_credentials("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # update_datasource_credentials (lines 
917-978) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user): + mock_db_session.query().first.return_value = None + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="not found"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name") + + def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "e"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="already exists"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name") + + def test_should_raise_value_error_when_credential_validation_fails_on_update( + self, service, mock_db_session, mock_user + ): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "e"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")), + ): + with pytest.raises(ValueError, match="Failed to validate"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name") + + def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user): + """Verifies 
that encrypted_credentials is reassigned with encrypted value and commit is called.""" + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "old_enc"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + patch.object(service.provider_manager, "validate_provider_credentials"), + ): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name") + # encrypter must have been called with the new secret value + self._enc.encrypt_token.assert_called() + # commit must be called exactly once + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # remove_datasource_credentials (lines 980-997) + # ----------------------------------------------------------------------- + + def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + service.remove_datasource_credentials("t1", "id", "prov", "org/plug") + mock_db_session.delete.assert_called_once_with(p) + mock_db_session.commit.assert_called_once() + + def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session): + """No error raised; no delete called when record doesn't exist (lines 994 branch).""" + mock_db_session.query().first.return_value = None + service.remove_datasource_credentials("t1", "id", "prov", "org/plug") + mock_db_session.delete.assert_not_called() diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py index 7f087a17d8..a3b1f46436 100644 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ 
b/api/tests/unit_tests/services/test_end_user_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, EndUser +from models.model import App, DefaultEndUserSessionID, EndUser from services.end_user_service import EndUserService @@ -44,6 +44,145 @@ class TestEndUserServiceFactory: return end_user +class TestEndUserServiceGetEndUserById: + """Unit tests for EndUserService.get_end_user_by_id method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): + """Test successful retrieval of end user by ID.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = mock_end_user + + # Act + result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + assert result == mock_end_user + mock_session.query.assert_called_once_with(EndUser) + mock_query.where.assert_called_once() + mock_query.first.assert_called_once() + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): + """Test retrieval of non-existent end user returns None.""" + 
# Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act + result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + assert result is None + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): + """Test that query parameters are correctly applied.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act + EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + # Verify the where clause was called with the correct conditions + call_args = mock_query.where.call_args[0] + assert len(call_args) == 3 + # Check that the conditions match the expected filters + # (We can't easily test the exact conditions without importing SQLAlchemy) + + +class TestEndUserServiceGetOrCreateEndUser: + """Unit tests for EndUserService.get_or_create_end_user method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") + def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, 
factory): + """Test get_or_create_end_user with specific user_id.""" + # Arrange + app_mock = factory.create_app_mock() + user_id = "user-123" + expected_end_user = factory.create_end_user_mock() + mock_get_or_create_by_type.return_value = expected_end_user + + # Act + result = EndUserService.get_or_create_end_user(app_mock, user_id) + + # Assert + assert result == expected_end_user + mock_get_or_create_by_type.assert_called_once_with( + InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id + ) + + @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") + def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): + """Test get_or_create_end_user without user_id (None).""" + # Arrange + app_mock = factory.create_app_mock() + expected_end_user = factory.create_end_user_mock() + mock_get_or_create_by_type.return_value = expected_end_user + + # Act + result = EndUserService.get_or_create_end_user(app_mock, None) + + # Assert + assert result == expected_end_user + mock_get_or_create_by_type.assert_called_once_with( + InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None + ) + + class TestEndUserServiceGetOrCreateEndUserByType: """ Unit tests for EndUserService.get_or_create_end_user_by_type method. 
@@ -60,6 +199,191 @@ class TestEndUserServiceGetOrCreateEndUserByType: """Provide test data factory.""" return TestEndUserServiceFactory() + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): + """Test creating a new end user with specific user_id.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None # No existing user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + # Verify new EndUser was created with correct parameters + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + added_user = mock_session.add.call_args[0][0] + assert added_user.tenant_id == tenant_id + assert added_user.app_id == app_id + assert added_user.type == type_enum + assert added_user.session_id == user_id + assert added_user.external_user_id == user_id + assert added_user._is_anonymous is False + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): + """Test creating a new end user with default session ID.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = None + type_enum = InvokeFrom.WEB_APP + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + 
mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None # No existing user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + added_user = mock_session.add.call_args[0][0] + assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert added_user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + @patch("services.end_user_service.logger") + def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): + """Test retrieving existing user with same type.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user = factory.create_end_user_mock( + tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = existing_user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + assert result == existing_user + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + mock_logger.info.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + @patch("services.end_user_service.logger") + def 
test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): + """Test upgrading existing user with different type.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + old_type = InvokeFrom.WEB_APP + new_type = InvokeFrom.SERVICE_API + + existing_user = factory.create_end_user_mock( + tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = existing_user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + assert result == existing_user + assert existing_user.type == new_type + mock_session.commit.assert_called_once() + mock_logger.info.assert_called_once() + logger_call_args = mock_logger.info.call_args[0] + assert "Upgrading legacy EndUser" in logger_call_args[0] + # The old and new types are passed as separate arguments + assert mock_logger.info.call_args[0][1] == existing_user.id + assert mock_logger.info.call_args[0][2] == old_type + assert mock_logger.info.call_args[0][3] == new_type + assert mock_logger.info.call_args[0][4] == user_id + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): + """Test that query ordering prioritizes exact type matches.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + target_type = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = 
mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None + + # Act + EndUserService.get_or_create_end_user_by_type( + type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + mock_query.order_by.assert_called_once() + # Verify that case statement is used for ordering + order_by_call = mock_query.order_by.call_args[0][0] + # The exact structure depends on SQLAlchemy's case implementation + # but we can verify it was called + # Test 10: Session context manager properly closes @patch("services.end_user_service.Session") @patch("services.end_user_service.db") @@ -93,3 +417,425 @@ class TestEndUserServiceGetOrCreateEndUserByType: # Verify context manager was entered and exited mock_context.__enter__.assert_called_once() mock_context.__exit__.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): + """Test that all InvokeFrom enum values are supported.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + + for invoke_type in InvokeFrom: + with patch("services.end_user_service.Session") as mock_session_class: + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + added_user = 
mock_session.add.call_args[0][0] + assert added_user.type == invoke_type + + +class TestEndUserServiceCreateEndUserBatch: + """Unit tests for EndUserService.create_end_user_batch method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): + """Test batch creation with empty app_ids list.""" + # Arrange + tenant_id = "tenant-123" + app_ids: list[str] = [] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert result == {} + mock_session_class.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_default_session_id(self, mock_db, mock_session_class): + """Test batch creation with empty user_id (uses default session).""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + user_id = "" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 2 + for app_id, end_user in result.items(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + 
@patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): + """Test that duplicate app_ids are deduplicated while preserving order.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + # Should have 3 unique app_ids in original order + assert len(result) == 3 + assert "app-456" in result + assert "app-789" in result + assert "app-123" in result + + # Verify the order is preserved + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 3 + assert added_users[0].app_id == "app-456" + assert added_users[1].app_id == "app-789" + assert added_users[2].app_id == "app-123" + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): + """Test batch creation when all users already exist.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user1 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + existing_user2 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() 
+ mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1, existing_user2] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 2 + assert result["app-456"] == existing_user1 + assert result["app-789"] == existing_user2 + mock_session.add_all.assert_not_called() + mock_session.commit.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): + """Test batch creation with some existing and some new users.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-123"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user1 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + # app-789 and app-123 don't exist + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 3 + assert result["app-456"] == existing_user1 + assert "app-789" in result + assert "app-123" in result + + # Should create 2 new users + mock_session.add_all.assert_called_once() + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 2 + + 
mock_session.commit.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): + """Test batch creation handles duplicates in existing users gracefully.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + # Simulate duplicate records in database + existing_user1 = factory.create_end_user_mock( + user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + existing_user2 = factory.create_end_user_mock( + user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1, existing_user2] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 1 + # Should prefer the first one found + assert result["app-456"] == existing_user1 + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): + """Test batch creation with all InvokeFrom types.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + + for invoke_type in InvokeFrom: + with patch("services.end_user_service.Session") as mock_session_class: + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + 
mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + added_user = mock_session.add_all.call_args[0][0][0] + assert added_user.type == invoke_type + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): + """Test batch creation with single app_id.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 1 + assert "app-456" in result + mock_session.add_all.assert_called_once() + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 1 + assert added_users[0].app_id == "app-456" + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): + """Test batch creation correctly sets anonymous flag.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + + # Test with regular user ID + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + 
mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act - authenticated user + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" + ) + + # Assert + added_users = mock_session.add_all.call_args[0][0] + for user in added_users: + assert user._is_anonymous is False + + # Test with default session ID + mock_session.reset_mock() + mock_query.reset_mock() + mock_query.all.return_value = [] + + # Act - anonymous user + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=app_ids, + user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, + ) + + # Assert + added_users = mock_session.add_all.call_args[0][0] + for user in added_users: + assert user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): + """Test that batch creation uses efficient single query for existing users.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-123"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) + + # Assert + # Should make exactly one query to check for existing users + mock_session.query.assert_called_once_with(EndUser) + mock_query.where.assert_called_once() + mock_query.all.assert_called_once() + + # Verify the where clause 
uses .in_() for app_ids + where_call = mock_query.where.call_args[0] + # The exact structure depends on SQLAlchemy implementation + # but we can verify it was called with the right parameters + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_session_context_manager(self, mock_db, mock_session_class): + """Test that batch creation properly uses session context manager.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) + + # Assert + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_export_app_messages.py b/api/tests/unit_tests/services/test_export_app_messages.py new file mode 100644 index 0000000000..5f2d3f21c0 --- /dev/null +++ b/api/tests/unit_tests/services/test_export_app_messages.py @@ -0,0 +1,43 @@ +import datetime + +import pytest + +from services.retention.conversation.message_export_service import AppMessageExportService + + +def test_validate_export_filename_accepts_relative_path(): + assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01" + + +@pytest.mark.parametrize( + "filename", + [ + "test01.jsonl.gz", + "test01.jsonl", + "test01.gz", + "/tmp/test01", + "exports/../test01", + "bad\x00name", + "bad\tname", + "a" * 1025, + ], +) +def 
test_validate_export_filename_rejects_invalid_values(filename: str): + with pytest.raises(ValueError): + AppMessageExportService.validate_export_filename(filename) + + +def test_service_derives_output_names_from_filename_base(): + service = AppMessageExportService( + app_id="736b9b03-20f2-4697-91da-8d00f6325900", + start_from=None, + end_before=datetime.datetime(2026, 3, 1), + filename="exports/2026/test01", + batch_size=1000, + use_cloud_storage=True, + dry_run=True, + ) + + assert service._filename_base == "exports/2026/test01" + assert service.output_gz_name == "exports/2026/test01.jsonl.gz" + assert service.output_jsonl_name == "exports/2026/test01.jsonl" diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py new file mode 100644 index 0000000000..b7259c3e82 --- /dev/null +++ b/api/tests/unit_tests/services/test_file_service.py @@ -0,0 +1,420 @@ +import base64 +import hashlib +import os +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker +from werkzeug.exceptions import NotFound + +from configs import dify_config +from models.enums import CreatorUserRole +from models.model import Account, EndUser, UploadFile +from services.errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError +from services.file_service import FileService + + +class TestFileService: + @pytest.fixture + def mock_db_session(self): + session = MagicMock(spec=Session) + # Mock context manager behavior + session.__enter__.return_value = session + return session + + @pytest.fixture + def mock_session_maker(self, mock_db_session): + maker = MagicMock(spec=sessionmaker) + maker.return_value = mock_db_session + return maker + + @pytest.fixture + def file_service(self, mock_session_maker): + return FileService(session_factory=mock_session_maker) + + def test_init_with_engine(self): + engine = MagicMock(spec=Engine) + 
service = FileService(session_factory=engine) + assert isinstance(service._session_maker, sessionmaker) + + def test_init_with_sessionmaker(self): + maker = MagicMock(spec=sessionmaker) + service = FileService(session_factory=maker) + assert service._session_maker == maker + + def test_init_invalid_factory(self): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + FileService(session_factory="invalid") + + @patch("services.file_service.storage") + @patch("services.file_service.naive_utc_now") + @patch("services.file_service.extract_tenant_id") + @patch("services.file_service.file_helpers.get_signed_file_url") + def test_upload_file_success( + self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session + ): + # Setup + mock_tenant_id.return_value = "tenant_id" + mock_now.return_value = "2024-01-01" + mock_get_url.return_value = "http://signed-url" + + user = MagicMock(spec=Account) + user.id = "user_id" + content = b"file content" + filename = "test.jpg" + mimetype = "image/jpeg" + + # Execute + result = file_service.upload_file(filename=filename, content=content, mimetype=mimetype, user=user) + + # Assert + assert isinstance(result, UploadFile) + assert result.name == filename + assert result.tenant_id == "tenant_id" + assert result.size == len(content) + assert result.extension == "jpg" + assert result.mime_type == mimetype + assert result.created_by_role == CreatorUserRole.ACCOUNT + assert result.created_by == "user_id" + assert result.hash == hashlib.sha3_256(content).hexdigest() + assert result.source_url == "http://signed-url" + + mock_storage.save.assert_called_once() + mock_db_session.add.assert_called_once_with(result) + mock_db_session.commit.assert_called_once() + + def test_upload_file_invalid_characters(self, file_service): + with pytest.raises(ValueError, match="Filename contains invalid characters"): + file_service.upload_file(filename="invalid/file.txt", content=b"", 
mimetype="text/plain", user=MagicMock()) + + def test_upload_file_long_filename(self, file_service, mock_db_session): + # Setup + long_name = "a" * 210 + ".txt" + user = MagicMock(spec=Account) + user.id = "user_id" + + with ( + patch("services.file_service.storage"), + patch("services.file_service.extract_tenant_id") as mock_tenant, + patch("services.file_service.file_helpers.get_signed_file_url"), + ): + mock_tenant.return_value = "tenant" + result = file_service.upload_file(filename=long_name, content=b"test", mimetype="text/plain", user=user) + assert len(result.name) <= 205 # 200 + . + extension + assert result.name.endswith(".txt") + + def test_upload_file_blocked_extension(self, file_service): + with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe"): + with pytest.raises(BlockedFileExtensionError): + file_service.upload_file( + filename="test.exe", content=b"", mimetype="application/octet-stream", user=MagicMock() + ) + + def test_upload_file_unsupported_type_for_datasets(self, file_service): + with pytest.raises(UnsupportedFileTypeError): + file_service.upload_file( + filename="test.jpg", content=b"", mimetype="image/jpeg", user=MagicMock(), source="datasets" + ) + + def test_upload_file_too_large(self, file_service): + # 16MB file for an image with 15MB limit + content = b"a" * (16 * 1024 * 1024) + with patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 15): + with pytest.raises(FileTooLargeError): + file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock()) + + def test_upload_file_end_user(self, file_service, mock_db_session): + user = MagicMock(spec=EndUser) + user.id = "end_user_id" + + with ( + patch("services.file_service.storage"), + patch("services.file_service.extract_tenant_id") as mock_tenant, + patch("services.file_service.file_helpers.get_signed_file_url"), + ): + mock_tenant.return_value = "tenant" + result = file_service.upload_file(filename="test.txt", 
content=b"test", mimetype="text/plain", user=user) + assert result.created_by_role == CreatorUserRole.END_USER + + def test_is_file_size_within_limit(self): + with ( + patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 10), + patch.object(dify_config, "UPLOAD_VIDEO_FILE_SIZE_LIMIT", 20), + patch.object(dify_config, "UPLOAD_AUDIO_FILE_SIZE_LIMIT", 30), + patch.object(dify_config, "UPLOAD_FILE_SIZE_LIMIT", 5), + ): + # Image + assert FileService.is_file_size_within_limit(extension="jpg", file_size=10 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="png", file_size=11 * 1024 * 1024) is False + + # Video + assert FileService.is_file_size_within_limit(extension="mp4", file_size=20 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="avi", file_size=21 * 1024 * 1024) is False + + # Audio + assert FileService.is_file_size_within_limit(extension="mp3", file_size=30 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="wav", file_size=31 * 1024 * 1024) is False + + # Default + assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False + + def test_get_file_base64_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "test_key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load_once.return_value = b"test content" + + # Execute + result = file_service.get_file_base64("file_id") + + # Assert + assert result == base64.b64encode(b"test content").decode() + mock_storage.load_once.assert_called_once_with("test_key") + + def test_get_file_base64_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with 
pytest.raises(NotFound, match="File not found"): + file_service.get_file_base64("non_existent") + + def test_upload_text_success(self, file_service, mock_db_session): + # Setup + text = "sample text" + text_name = "test.txt" + user_id = "user_id" + tenant_id = "tenant_id" + + with patch("services.file_service.storage") as mock_storage: + # Execute + result = file_service.upload_text(text, text_name, user_id, tenant_id) + + # Assert + assert result.name == text_name + assert result.size == len(text) + assert result.tenant_id == tenant_id + assert result.created_by == user_id + assert result.used is True + assert result.extension == "txt" + mock_storage.save.assert_called_once() + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_upload_text_long_name(self, file_service, mock_db_session): + long_name = "a" * 210 + with patch("services.file_service.storage"): + result = file_service.upload_text("text", long_name, "user", "tenant") + assert len(result.name) == 200 + + def test_get_file_preview_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "pdf" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract: + mock_extract.return_value = "Extracted text content" + + # Execute + result = file_service.get_file_preview("file_id") + + # Assert + assert result == "Extracted text content" + + def test_get_file_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_preview("non_existent") + + def test_get_file_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "exe" + 
mock_db_session.query().where().first.return_value = upload_file + with pytest.raises(UnsupportedFileTypeError): + file_service.get_file_preview("file_id") + + def test_get_image_preview_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "jpg" + upload_file.mime_type = "image/jpeg" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with ( + patch("services.file_service.file_helpers.verify_image_signature") as mock_verify, + patch("services.file_service.storage") as mock_storage, + ): + mock_verify.return_value = True + mock_storage.load.return_value = iter([b"chunk1"]) + + # Execute + gen, mime = file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + # Assert + assert list(gen) == [b"chunk1"] + assert mime == "image/jpeg" + + def test_get_image_preview_invalid_sig(self, file_service): + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = False + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_image_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_image_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "txt" + mock_db_session.query().where().first.return_value = upload_file + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + 
mock_verify.return_value = True + with pytest.raises(UnsupportedFileTypeError): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with ( + patch("services.file_service.file_helpers.verify_file_signature") as mock_verify, + patch("services.file_service.storage") as mock_storage, + ): + mock_verify.return_value = True + mock_storage.load.return_value = iter([b"chunk"]) + + gen, file = file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + assert list(gen) == [b"chunk"] + assert file == upload_file + + def test_get_file_generator_by_file_id_invalid_sig(self, file_service): + with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: + mock_verify.return_value = False + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + + def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + + def test_get_public_image_preview_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "png" + upload_file.mime_type = "image/png" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + 
mock_storage.load.return_value = b"image content" + gen, mime = file_service.get_public_image_preview("file_id") + assert gen == b"image content" + assert mime == "image/png" + + def test_get_public_image_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_public_image_preview("file_id") + + def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "txt" + mock_db_session.query().where().first.return_value = upload_file + with pytest.raises(UnsupportedFileTypeError): + file_service.get_public_image_preview("file_id") + + def test_get_file_content_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load.return_value = b"hello world" + result = file_service.get_file_content("file_id") + assert result == "hello world" + + def test_get_file_content_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_content("file_id") + + def test_delete_file_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + # For session.scalar(select(...)) + mock_db_session.scalar.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + file_service.delete_file("file_id") + mock_storage.delete.assert_called_once_with("key") + mock_db_session.delete.assert_called_once_with(upload_file) + + def 
test_delete_file_not_found(self, file_service, mock_db_session): + mock_db_session.scalar.return_value = None + file_service.delete_file("file_id") + # Should return without doing anything + + @patch("services.file_service.db") + def test_get_upload_files_by_ids_empty(self, mock_db): + result = FileService.get_upload_files_by_ids("tenant_id", []) + assert result == {} + + @patch("services.file_service.db") + def test_get_upload_files_by_ids(self, mock_db): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "550e8400-e29b-41d4-a716-446655440000" + upload_file.tenant_id = "tenant_id" + mock_db.session.scalars().all.return_value = [upload_file] + + result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) + assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file + + def test_sanitize_zip_entry_name(self): + assert FileService._sanitize_zip_entry_name("path/to/file.txt") == "file.txt" + assert FileService._sanitize_zip_entry_name("../../../etc/passwd") == "passwd" + assert FileService._sanitize_zip_entry_name(" ") == "file" + assert FileService._sanitize_zip_entry_name("a\\b") == "a_b" + + def test_dedupe_zip_entry_name(self): + used = {"a.txt"} + assert FileService._dedupe_zip_entry_name("b.txt", used) == "b.txt" + assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (1).txt" + used.add("a (1).txt") + assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (2).txt" + + def test_build_upload_files_zip_tempfile(self): + upload_file = MagicMock(spec=UploadFile) + upload_file.name = "test.txt" + upload_file.key = "key" + + with ( + patch("services.file_service.storage") as mock_storage, + patch("services.file_service.os.remove") as mock_remove, + ): + mock_storage.load.return_value = [b"chunk1", b"chunk2"] + + with FileService.build_upload_files_zip_tempfile(upload_files=[upload_file]) as tmp_path: + assert os.path.exists(tmp_path) + + mock_remove.assert_called_once() diff --git 
a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/unit_tests/services/test_hit_testing_service.py new file mode 100644 index 0000000000..80e9729f5b --- /dev/null +++ b/api/tests/unit_tests/services/test_hit_testing_service.py @@ -0,0 +1,385 @@ +import json +from typing import Any, cast +from unittest.mock import ANY, MagicMock, patch + +import pytest + +from core.rag.models.document import Document +from models.dataset import Dataset +from services.hit_testing_service import HitTestingService + + +class TestHitTestingService: + """Test suite for HitTestingService""" + + # ===== Utility Method Tests ===== + + def test_escape_query_for_search_should_escape_double_quotes(self): + """Test that escape_query_for_search escapes double quotes correctly""" + # Arrange + query = 'test "query" with quotes' + expected = 'test \\"query\\" with quotes' + + # Act + result = HitTestingService.escape_query_for_search(query) + + # Assert + assert result == expected + + def test_hit_testing_args_check_should_pass_with_valid_query(self): + """Test that hit_testing_args_check passes with a valid query""" + # Arrange + args = {"query": "valid query"} + + # Act & Assert (should not raise) + HitTestingService.hit_testing_args_check(args) + + def test_hit_testing_args_check_should_pass_with_valid_attachments(self): + """Test that hit_testing_args_check passes with valid attachment_ids""" + # Arrange + args = {"attachment_ids": ["id1", "id2"]} + + # Act & Assert (should not raise) + HitTestingService.hit_testing_args_check(args) + + def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): + """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" + # Arrange + args = {} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Query or attachment_ids is required" in str(exc_info.value) + + def 
test_hit_testing_args_check_should_raise_error_when_query_too_long(self): + """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" + # Arrange + args = {"query": "a" * 251} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Query cannot exceed 250 characters" in str(exc_info.value) + + def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): + """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" + # Arrange + args = {"attachment_ids": "not a list"} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Attachment_ids must be a list" in str(exc_info.value) + + # ===== Response Formatting Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") + def test_compact_retrieve_response_should_format_correctly(self, mock_format): + """Test that compact_retrieve_response formats the response correctly""" + # Arrange + query = "test query" + mock_doc = MagicMock(spec=Document) + documents = [mock_doc] + + mock_record = MagicMock() + mock_record.model_dump.return_value = {"content": "formatted content"} + mock_format.return_value = [mock_record] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert len(result["records"]) == 1 + assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" + mock_format.assert_called_once_with(documents) + + def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): + """Test that compact_external_retrieve_response returns records when dataset provider is external""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "external" + query = "test 
query" + documents = [ + {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, + {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, + ] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert len(result["records"]) == 2 + assert cast(dict[str, Any], result["records"][0])["content"] == "c1" + assert cast(dict[str, Any], result["records"][1])["title"] == "t2" + + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): + """Test that compact_external_retrieve_response returns empty records for non-external provider""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "not_external" + query = "test query" + documents = [{"content": "c1"}] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert result["records"] == [] + + # ===== External Retrieve Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): + """Test that external_retrieve successfully retrieves from external provider and commits query""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + dataset.provider = "external" + query = 'test "query"' + account = MagicMock() + account.id = "account_id" + + mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] + + # Act + result = cast( + dict[str, Any], + HitTestingService.external_retrieve( + dataset=dataset, + query=query, + account=account, + 
external_retrieval_model={"model": "test"}, + metadata_filtering_conditions={"key": "val"}, + ), + ) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" + + # Verify call to RetrievalService.external_retrieve with escaped query + mock_ext_retrieve.assert_called_once_with( + dataset_id="dataset_id", + query='test \\"query\\"', + external_retrieval_model={"model": "test"}, + metadata_filtering_conditions={"key": "val"}, + ) + + # Verify DatasetQuery record was added and committed + mock_add.assert_called_once() + mock_commit.assert_called_once() + + def test_external_retrieve_should_return_empty_for_non_external_provider(self): + """Test that external_retrieve returns empty results immediately if provider is not external""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "not_external" + query = "test query" + account = MagicMock() + + # Act + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert result["records"] == [] + + # ===== Retrieve Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve uses default model when retrieval_model is not provided""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + dataset.retrieval_model = None + query = "test query" + account = MagicMock() + account.id = "account_id" + + mock_retrieve.return_value = [] + + # Act + result = cast( + dict[str, Any], + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + ), + ) 
+ + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + mock_retrieve.assert_called_once() + # Verify top_k from default_retrieval_model (4) + assert mock_retrieve.call_args.kwargs["top_k"] == 4 + mock_commit.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): + """Test that retrieve correctly calls metadata filtering when conditions are present""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + + retrieval_model = { + "search_method": "semantic_search", + "metadata_filtering_conditions": {"some": "condition"}, + "top_k": 5, + "reranking_enable": False, + "score_threshold_enabled": False, + } + + # Mock metadata filtering response + mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + ) + + # Assert + mock_get_meta.assert_called_once() + mock_retrieve.assert_called_once() + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") + def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): + """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" + # Arrange + dataset = 
MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + + retrieval_model = { + "search_method": "semantic_search", + "metadata_filtering_conditions": {"some": "condition"}, + "top_k": 5, + "reranking_enable": False, + "score_threshold_enabled": False, + } + + # Mock metadata filtering response: condition returned but no IDs + mock_get_meta.return_value = ({}, "condition_string") + + # Act + result = cast( + dict[str, Any], + HitTestingService.retrieve( + dataset=dataset, + query=query, + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, + ), + ) + + # Assert + assert result["records"] == [] + mock_retrieve.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + attachment_ids = ["att1", "att2"] + + retrieval_model = { + "search_method": "semantic_search", + "top_k": 4, + "reranking_enable": False, + "score_threshold_enabled": False, + } + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, + query=query, + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, + attachment_ids=attachment_ids, + ) + + # Assert + mock_retrieve.assert_called_once_with( + retrieval_method=ANY, + dataset_id="dataset_id", + query=query, + attachment_ids=attachment_ids, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + ) + # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 
images) + # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) + called_query = mock_add.call_args[0][0] + query_content = json.loads(called_query.content) + assert len(query_content) == 3 # 1 text + 2 images + assert query_content[0]["content_type"] == "text_query" + assert query_content[1]["content_type"] == "image_query" + assert query_content[1]["content"] == "att1" + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve passes reranking and threshold parameters correctly""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + + retrieval_model = { + "search_method": "hybrid_search", + "top_k": 10, + "reranking_enable": True, + "reranking_model": {"provider": "test"}, + "reranking_mode": "weighted_sum", + "score_threshold_enabled": True, + "score_threshold": 0.5, + "weights": {"vector": 0.5, "keyword": 0.5}, + } + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + ) + + # Assert + mock_retrieve.assert_called_once() + kwargs = mock_retrieve.call_args.kwargs + assert kwargs["score_threshold"] == 0.5 + assert kwargs["reranking_model"] == {"provider": "test"} + assert kwargs["reranking_mode"] == "weighted_sum" + assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5} diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index e64d3c5406..74139fd12d 100644 --- 
a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -1,97 +1,291 @@ from types import SimpleNamespace +from unittest.mock import MagicMock, patch import pytest +from sqlalchemy.engine import Engine +from configs import dify_config from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, + MemberRecipient, ) from dify_graph.runtime import VariablePool from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, + DeliveryTestEmailRecipient, DeliveryTestError, + DeliveryTestRegistry, + DeliveryTestResult, + DeliveryTestStatus, + DeliveryTestUnsupportedError, EmailDeliveryTestHandler, + HumanInputDeliveryTestService, + _build_form_link, ) -def _make_email_method() -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", +@pytest.fixture +def mock_db(monkeypatch): + mock_db = MagicMock() + monkeypatch.setattr(service_module, "db", mock_db) + return mock_db + + +def _make_valid_email_config(): + return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + + +def test_build_form_link(): + with patch.object(dify_config, "APP_WEB_URL", "http://example.com/"): + assert _build_form_link("token123") == "http://example.com/form/token123" + + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + assert _build_form_link("token123") == "http://example.com/form/token123" + + assert _build_form_link(None) is None + + with patch.object(dify_config, "APP_WEB_URL", None): + assert _build_form_link("token123") is None + + +class TestDeliveryTestRegistry: + def 
test_register(self): + registry = DeliveryTestRegistry() + assert len(registry._handlers) == 0 + handler = MagicMock() + registry.register(handler) + assert len(registry._handlers) == 1 + assert registry._handlers[0] == handler + + def test_register_and_dispatch(self): + handler = MagicMock() + handler.supports.return_value = True + handler.send_test.return_value = DeliveryTestResult(status=DeliveryTestStatus.OK) + + registry = DeliveryTestRegistry([handler]) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + result = registry.dispatch(context=context, method=method) + + assert result.status == DeliveryTestStatus.OK + handler.supports.assert_called_once_with(method) + handler.send_test.assert_called_once_with(context=context, method=method) + + def test_dispatch_unsupported(self): + handler = MagicMock() + handler.supports.return_value = False + + registry = DeliveryTestRegistry([handler]) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): + registry.dispatch(context=context, method=method) + + def test_default(self, mock_db): + registry = DeliveryTestRegistry.default() + assert len(registry._handlers) == 1 + assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) + + +def test_human_input_delivery_test_service(): + registry = MagicMock(spec=DeliveryTestRegistry) + service = HumanInputDeliveryTestService(registry=registry) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + service.send_test(context=context, method=method) + registry.dispatch.assert_called_once_with(context=context, method=method) + + +class TestEmailDeliveryTestHandler: + def test_init_with_engine(self): + engine = MagicMock(spec=Engine) + handler = EmailDeliveryTestHandler(session_factory=engine) + assert handler._session_factory.kw["bind"] == engine + + def test_supports(self): + handler = 
EmailDeliveryTestHandler(session_factory=MagicMock()) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + assert handler.supports(method) is True + assert handler.supports(MagicMock()) is False + + def test_send_test_unsupported_method(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + with pytest.raises(DeliveryTestUnsupportedError): + handler.send_test(context=MagicMock(), method=MagicMock()) + + def test_send_test_feature_disabled(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), ) - ) - - -def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - ) - method = _make_email_method() - - with pytest.raises(DeliveryTestError, match="Email delivery is not available"): - handler.send_test(context=context, method=method) - - -def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - class DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] - - def is_inited(self) -> bool: - return True - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - mail = DummyMail() - monkeypatch.setattr(service_module, "mail", mail) - monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) 
- - handler = EmailDeliveryTestHandler(session_factory=object()) - handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), - subject="Subject", - body="Value {{#node1.value#}}", + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" ) - ) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - variable_pool=variable_pool, - ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) - handler.send_test(context=context, method=method) + with pytest.raises(DeliveryTestError, match="Email delivery is not available"): + handler.send_test(context=context, method=method) - assert mail.sent[0]["html"] == "Value OK" + def test_send_test_mail_not_inited(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: False) + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + with pytest.raises(DeliveryTestError, match="Mail client is not initialized."): + handler.send_test(context=context, method=method) + + def test_send_test_no_recipients(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: 
SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=[]) + + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + with pytest.raises(DeliveryTestError, match="No recipients configured"): + handler.send_test(context=context, method=method) + + def test_send_test_success(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + mock_mail_send = MagicMock() + monkeypatch.setattr(service_module.mail, "send", mock_mail_send) + monkeypatch.setattr(service_module, "render_email_template", lambda t, s: f"RENDERED_{t}") + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=["test@example.com"]) + + variable_pool = VariablePool() + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + variable_pool=variable_pool, + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + result = handler.send_test(context=context, method=method) + + assert result.status == DeliveryTestStatus.OK + assert result.delivered_to == ["test@example.com"] + mock_mail_send.assert_called_once() + args, kwargs = mock_mail_send.call_args + assert kwargs["to"] == "test@example.com" + assert "RENDERED_Subj" in kwargs["subject"] + + def test_resolve_recipients(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + 
+ # Test Case 1: External Recipient + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + subject="", + body="", + ) + ) + assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] + + # Test Case 2: Member Recipient + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + subject="", + body="", + ) + ) + handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) + assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] + + # Test Case 3: Whole Workspace + method = EmailDeliveryMethod( + config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + ) + handler._query_workspace_member_emails = MagicMock( + return_value={"u1": "u1@example.com", "u2": "u2@example.com"} + ) + recipients = handler._resolve_recipients(tenant_id="t1", method=method) + assert set(recipients) == {"u1@example.com", "u2@example.com"} + + def test_query_workspace_member_emails(self): + mock_session = MagicMock() + mock_session_factory = MagicMock(return_value=mock_session) + mock_session.__enter__.return_value = mock_session + + handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) + + # Empty user_ids + assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} + + # user_ids is None (all) + mock_execute = MagicMock() + mock_session.execute.return_value = mock_execute + mock_execute.all.return_value = [("u1", "u1@example.com")] + + result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) + assert result == {"u1": "u1@example.com"} + + # user_ids with values + result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) + assert result == {"u1": "u1@example.com"} 
+ + def test_build_substitutions(self): + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + template_vars={"custom": "var"}, + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + + assert subs["node_title"] == "title" + assert subs["form_content"] == "content" + assert subs["recipient_email"] == "test@example.com" + assert subs["custom"] == "var" + assert subs["form_token"] == "token123" + assert "form/token123" in subs["form_link"] + + # Without matching recipient + subs_no_match = EmailDeliveryTestHandler._build_substitutions( + context=context, recipient_email="other@example.com" + ) + assert subs_no_match["form_token"] == "" + assert subs_no_match["form_link"] == "" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index a4c6c50593..375e47d7fc 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -16,7 +16,13 @@ from dify_graph.nodes.human_input.entities import ( ) from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus from models.human_input import RecipientType -from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError +from services.human_input_service import ( + Form, + FormExpiredError, + FormSubmittedError, + HumanInputService, + InvalidFormDataError, +) @pytest.fixture @@ -285,3 +291,172 @@ def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_fa assert "Missing required inputs" in str(exc_info.value) repo.mark_submitted.assert_not_called() + + +def test_form_properties(sample_form_record): + form = Form(sample_form_record) + assert form.id 
== "form-id" + assert form.workflow_run_id == "workflow-run-id" + assert form.tenant_id == "tenant-id" + assert form.app_id == "app-id" + assert form.recipient_id == "recipient-id" + assert form.recipient_type == RecipientType.STANDALONE_WEB_APP + assert form.status == HumanInputFormStatus.WAITING + assert form.form_kind == HumanInputFormKind.RUNTIME + assert isinstance(form.created_at, datetime) + assert isinstance(form.expiration_time, datetime) + + +def test_form_submitted_error_init(): + error = FormSubmittedError(form_id="test-form") + assert "form_id=test-form" in error.description + assert error.code == 412 + + +def test_human_input_service_init_with_engine(mocker): + engine = MagicMock(spec=human_input_service_module.Engine) + sessionmaker_mock = mocker.patch("services.human_input_service.sessionmaker") + + HumanInputService(session_factory=engine) + sessionmaker_mock.assert_called_once_with(bind=engine) + + +def test_get_form_by_token_none(mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = None + + service = HumanInputService(session_factory, form_repository=repo) + assert service.get_form_by_token("invalid") is None + + +def test_get_form_definition_by_token_mismatch(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + service = HumanInputService(session_factory, form_repository=repo) + # RecipientType mismatch + assert service.get_form_definition_by_token(RecipientType.CONSOLE, "token") is None + + +def test_get_form_definition_by_token_success(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + service = HumanInputService(session_factory, 
form_repository=repo) + form = service.get_form_definition_by_token(RecipientType.STANDALONE_WEB_APP, "token") + assert form is not None + assert form.id == sample_form_record.form_id + + +def test_get_form_definition_by_token_for_console_mismatch(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record # is STANDALONE_WEB_APP + + service = HumanInputService(session_factory, form_repository=repo) + assert service.get_form_definition_by_token_for_console("token") is None + + +def test_submit_form_by_token_delivery_not_enabled(mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = None + + service = HumanInputService(session_factory, form_repository=repo) + with pytest.raises(human_input_service_module.WebAppDeliveryNotEnabledError): + service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "action", {}) + + +def test_submit_form_by_token_no_workflow_run_id(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + # Return record with no workflow_run_id + result_record = dataclasses.replace(sample_form_record, workflow_run_id=None) + repo.mark_submitted.return_value = result_record + + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "submit", {}) + enqueue_spy.assert_not_called() + + +def test_ensure_form_active_errors(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + # Submitted + submitted_record = 
dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + with pytest.raises(human_input_service_module.FormSubmittedError): + service.ensure_form_active(Form(submitted_record)) + + # Timeout status + timeout_record = dataclasses.replace(sample_form_record, status=HumanInputFormStatus.TIMEOUT) + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(timeout_record)) + + # Expired time + expired_time_record = dataclasses.replace( + sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + ) + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(expired_time_record)) + + +def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + + with pytest.raises(human_input_service_module.FormSubmittedError): + service._ensure_not_submitted(Form(submitted_record)) + + +def test_enqueue_resume_workflow_not_found(mocker, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = None + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + with pytest.raises(AssertionError) as excinfo: + service.enqueue_resume("workflow-run-id") + assert "WorkflowRun not found" in str(excinfo.value) + + +def test_enqueue_resume_app_not_found(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + 
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + session.execute.return_value.scalar_one_or_none.return_value = None + logger_spy = mocker.patch("services.human_input_service.logger") + + service.enqueue_resume("workflow-run-id") + logger_spy.error.assert_called_once() + + +def test_is_globally_expired_zero_timeout(monkeypatch, sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0) + assert service._is_globally_expired(Form(sample_form_record)) is False diff --git a/api/tests/unit_tests/services/test_knowledge_service.py b/api/tests/unit_tests/services/test_knowledge_service.py new file mode 100644 index 0000000000..bc0caee071 --- /dev/null +++ b/api/tests/unit_tests/services/test_knowledge_service.py @@ -0,0 +1,146 @@ +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pytest + +from services.knowledge_service import ExternalDatasetTestService + + +class TestKnowledgeService: + """Test suite for ExternalDatasetTestService""" + + # ===== Happy Path Tests ===== + + @patch("services.knowledge_service.boto3.client") + @patch("services.knowledge_service.dify_config") + def test_knowledge_retrieval_should_succeed_with_valid_results( + self, mock_dify_config: MagicMock, mock_boto_client: MagicMock + ): + """Test that knowledge_retrieval successfully parses results from Bedrock""" + # Arrange + mock_dify_config.AWS_SECRET_ACCESS_KEY = "dummy_secret" + mock_dify_config.AWS_ACCESS_KEY_ID = "dummy_id" + + mock_client = MagicMock() + mock_boto_client.return_value = mock_client + + retrieval_setting = {"top_k": 4, "score_threshold": 0.5} + query = "test query" + knowledge_id = "kb-123" + + # Mock successful response + mock_client.retrieve.return_value = { + "ResponseMetadata": 
{"HTTPStatusCode": 200}, + "retrievalResults": [ + { + "score": 0.9, + "metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc1.pdf"}, + "content": {"text": "content from doc1"}, + }, + { + "score": 0.4, # Below threshold + "metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc2.pdf"}, + "content": {"text": "content from doc2"}, + }, + ], + } + + # Act + result = cast( + dict[str, Any], ExternalDatasetTestService.knowledge_retrieval(retrieval_setting, query, knowledge_id) + ) + + # Assert + assert len(result["records"]) == 1 + record = result["records"][0] + assert record["score"] == 0.9 + assert record["title"] == "s3://bucket/doc1.pdf" + assert record["content"] == "content from doc1" + + # verify retrieve called correctly + mock_client.retrieve.assert_called_once_with( + knowledgeBaseId=knowledge_id, + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 4, + "overrideSearchType": "HYBRID", + } + }, + retrievalQuery={"text": query}, + ) + + # NEW: verify boto3.client created with proper service name and config values + mock_boto_client.assert_called_once_with( + "bedrock-agent-runtime", + aws_secret_access_key="dummy_secret", + aws_access_key_id="dummy_id", + region_name="us-east-1", + ) + + @patch("services.knowledge_service.boto3.client") + def test_knowledge_retrieval_should_return_empty_when_no_results(self, mock_boto: MagicMock): + """Test that knowledge_retrieval returns empty records when Bedrock returns nothing""" + # Arrange + mock_client = MagicMock() + mock_boto.return_value = mock_client + + mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []} + + # Act + result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + + # Assert + assert result["records"] == [] + + # ===== Error Handling Tests ===== + + @patch("services.knowledge_service.boto3.client") + def 
test_knowledge_retrieval_should_return_empty_on_http_error(self, mock_boto: MagicMock): + """Test that knowledge_retrieval returns empty records if Bedrock returns non-200 status""" + # Arrange + mock_client = MagicMock() + mock_boto.return_value = mock_client + + mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}} + + # Act + result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + + # Assert + assert result["records"] == [] + + def test_knowledge_retrieval_should_raise_when_boto_client_creation_fails(self): + """Test that exceptions from boto3.client propagate (e.g., network/credentials issues)""" + with patch("services.knowledge_service.boto3.client") as mock_boto: + mock_boto.side_effect = Exception("client init failed") + with pytest.raises(Exception) as exc_info: + ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb") + assert "client init failed" in str(exc_info.value) + + # ===== Edge Cases ===== + + @patch("services.knowledge_service.boto3.client") + def test_knowledge_retrieval_should_handle_missing_threshold_in_settings(self, mock_boto: MagicMock): + """Test that knowledge_retrieval uses 0.0 as default threshold if not provided""" + # Arrange + mock_client = MagicMock() + mock_boto.return_value = mock_client + + mock_client.retrieve.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "retrievalResults": [ + { + "score": 0.1, + "metadata": {"x-amz-bedrock-kb-source-uri": "uri"}, + "content": {"text": "text"}, + } + ], + } + + # Act + # retrieval_setting missing "score_threshold" + result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + + # Assert + assert len(result["records"]) == 1 + assert result["records"][0]["score"] == 0.1 diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 3c38888753..4b8bdde46b 100644 --- 
a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -5,8 +5,13 @@ import pytest from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.model import App, AppMode, EndUser, Message -from services.errors.message import FirstMessageNotExistsError, LastMessageNotExistsError -from services.message_service import MessageService +from services.errors.message import ( + FirstMessageNotExistsError, + LastMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) +from services.message_service import MessageService, attach_message_extra_contents class TestMessageServiceFactory: @@ -244,14 +249,12 @@ class TestMessageServicePaginationByFirstId: mock_query_first = MagicMock() mock_query_history = MagicMock() + query_calls = [] + def query_side_effect(*args): if args[0] == Message: - # First call returns mock for first_message query - if not hasattr(query_side_effect, "call_count"): - query_side_effect.call_count = 0 - query_side_effect.call_count += 1 - - if query_side_effect.call_count == 1: + query_calls.append(args) + if len(query_calls) == 1: return mock_query_first else: return mock_query_history @@ -647,3 +650,410 @@ class TestMessageServicePaginationByLastId: assert len(result.data) == 10 # Last message trimmed assert result.has_more is True assert result.limit == 10 + + +class TestMessageServiceUtilities: + """Unit tests for MessageService module-level utility functions.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 16: attach_message_extra_contents with empty list + def test_attach_message_extra_contents_empty(self): + """Test attach_message_extra_contents with empty list does nothing.""" + # Act & Assert (should not raise error) + attach_message_extra_contents([]) + + # Test 17: attach_message_extra_contents with messages + 
@patch("services.message_service._create_execution_extra_content_repository") + def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory): + """Test attach_message_extra_contents correctly attaches content.""" + # Arrange + messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")] + + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + + # Mock extra content models + mock_content1 = MagicMock() + mock_content1.model_dump.return_value = {"key": "value1"} + mock_content2 = MagicMock() + mock_content2.model_dump.return_value = {"key": "value2"} + + mock_repo.get_by_message_ids.return_value = [[mock_content1], [mock_content2]] + + # Act + attach_message_extra_contents(messages) + + # Assert + mock_repo.get_by_message_ids.assert_called_once_with(["msg-1", "msg-2"]) + messages[0].set_extra_contents.assert_called_once_with([{"key": "value1"}]) + messages[1].set_extra_contents.assert_called_once_with([{"key": "value2"}]) + + # Test 18: attach_message_extra_contents with index out of bounds + @patch("services.message_service._create_execution_extra_content_repository") + def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory): + """Test attach_message_extra_contents handles missing content lists.""" + # Arrange + messages = [factory.create_message_mock(message_id="msg-1")] + + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + mock_repo.get_by_message_ids.return_value = [] # Empty returned list + + # Act + attach_message_extra_contents(messages) + + # Assert + messages[0].set_extra_contents.assert_called_once_with([]) + + # Test 19: _create_execution_extra_content_repository + @patch("services.message_service.db") + @patch("services.message_service.sessionmaker") + @patch("services.message_service.SQLAlchemyExecutionExtraContentRepository") + def test_create_execution_extra_content_repository(self, mock_repo_class, 
mock_sessionmaker, mock_db): + """Test _create_execution_extra_content_repository creates expected repository.""" + from services.message_service import _create_execution_extra_content_repository + + # Act + _create_execution_extra_content_repository() + + # Assert + mock_sessionmaker.assert_called_once() + mock_repo_class.assert_called_once() + + +class TestMessageServiceGetMessage: + """Unit tests for MessageService.get_message method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 20: get_message success for EndUser + @patch("services.message_service.db") + def test_get_message_end_user_success(self, mock_db, factory): + """Test get_message returns message for EndUser.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock(user_id="end-user-123") + message = factory.create_message_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = message + + # Act + result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + # Assert + assert result == message + mock_query.where.assert_called_once() + + # Test 21: get_message success for Account (Admin) + @patch("services.message_service.db") + def test_get_message_account_success(self, mock_db, factory): + """Test get_message returns message for Account.""" + # Arrange + from models import Account + + app = factory.create_app_mock() + user = MagicMock(spec=Account) + user.id = "account-123" + message = factory.create_message_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = message + + # Act + result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + # Assert + assert result == message + + # Test 22: get_message not found + 
@patch("services.message_service.db") + def test_get_message_not_found(self, mock_db, factory): + """Test get_message raises MessageNotExistsError when not found.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + +class TestMessageServiceFeedback: + """Unit tests for MessageService feedback-related methods.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 23: create_feedback - new feedback for EndUser + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory): + """Test creating new feedback for an end user.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + message = factory.create_message_mock() + message.user_feedback = None + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating="like", + content="Good answer", + ) + + # Assert + assert result.rating == "like" + assert result.content == "Good answer" + assert result.from_source == "user" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + # Test 24: create_feedback - update feedback for Account + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_update_account(self, mock_get_message, mock_db, factory): + """Test updating existing feedback for an account.""" + # Arrange + from models import Account, MessageFeedback + + app = 
factory.create_app_mock() + user = MagicMock(spec=Account) + user.id = "account-123" + message = factory.create_message_mock() + feedback = MagicMock(spec=MessageFeedback) + message.admin_feedback = feedback + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating="dislike", + content="Bad answer", + ) + + # Assert + assert result == feedback + assert feedback.rating == "dislike" + assert feedback.content == "Bad answer" + mock_db.session.commit.assert_called_once() + + # Test 25: create_feedback - delete feedback (rating is None) + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_delete(self, mock_get_message, mock_db, factory): + """Test deleting feedback by passing rating=None.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + message = factory.create_message_mock() + feedback = MagicMock() + message.user_feedback = feedback + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating=None, + content=None, + ) + + # Assert + assert result == feedback + mock_db.session.delete.assert_called_once_with(feedback) + mock_db.session.commit.assert_called_once() + + # Test 26: get_all_messages_feedbacks + @patch("services.message_service.db") + def test_get_all_messages_feedbacks(self, mock_db, factory): + """Test get_all_messages_feedbacks returns list of dicts.""" + # Arrange + app = factory.create_app_mock() + feedback = MagicMock() + feedback.to_dict.return_value = {"id": "fb-1"} + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.limit.return_value = mock_query + mock_query.offset.return_value = mock_query + mock_query.all.return_value = 
[feedback] + + # Act + result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10) + + # Assert + assert result == [{"id": "fb-1"}] + mock_query.limit.assert_called_with(10) + mock_query.offset.assert_called_with(0) + + +class TestMessageServiceSuggestedQuestions: + """Unit tests for MessageService.get_suggested_questions_after_answer method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 27: get_suggested_questions_after_answer - user is None + def test_get_suggested_questions_user_none(self, factory): + app = factory.create_app_mock() + with pytest.raises(ValueError, match="user cannot be None"): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=None, message_id="msg-123", invoke_from=MagicMock() + ) + + # Test 28: get_suggested_questions_after_answer - Advanced Chat success + @patch("services.message_service.ModelManager") + @patch("services.message_service.WorkflowService") + @patch("services.message_service.AdvancedChatAppConfigManager") + @patch("services.message_service.TokenBufferMemory") + @patch("services.message_service.LLMGenerator") + @patch("services.message_service.TraceQueueManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_advanced_chat_success( + self, + mock_conversation_service, + mock_get_message, + mock_trace_manager, + mock_llm_gen, + mock_memory, + mock_config_manager, + mock_workflow_service, + mock_model_manager, + factory, + ): + """Test successful suggested questions generation in Advanced Chat mode.""" + from core.app.entities.app_invoke_entities import InvokeFrom + + # Arrange + app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value) + user = factory.create_end_user_mock() + message = factory.create_message_mock() + mock_get_message.return_value = message + + workflow = MagicMock() + 
mock_workflow_service.return_value.get_published_workflow.return_value = workflow + + app_config = MagicMock() + app_config.additional_features.suggested_questions_after_answer = True + mock_config_manager.get_app_config.return_value = app_config + + mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] + + # Act + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=InvokeFrom.WEB_APP + ) + + # Assert + assert result == ["Q1?"] + mock_workflow_service.return_value.get_published_workflow.assert_called_once() + mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() + + # Test 29: get_suggested_questions_after_answer - Chat app success (no override) + @patch("services.message_service.db") + @patch("services.message_service.ModelManager") + @patch("services.message_service.TokenBufferMemory") + @patch("services.message_service.LLMGenerator") + @patch("services.message_service.TraceQueueManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_chat_app_success( + self, + mock_conversation_service, + mock_get_message, + mock_trace_manager, + mock_llm_gen, + mock_memory, + mock_model_manager, + mock_db, + factory, + ): + """Test successful suggested questions generation in basic Chat mode.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.CHAT.value) + user = factory.create_end_user_mock() + message = factory.create_message_mock() + mock_get_message.return_value = message + + conversation = MagicMock() + conversation.override_model_configs = None + mock_conversation_service.get_conversation.return_value = conversation + + app_model_config = MagicMock() + app_model_config.suggested_questions_after_answer_dict = {"enabled": True} + app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"} + + mock_query = MagicMock() + 
mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = app_model_config + + mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] + + # Act + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock() + ) + + # Assert + assert result == ["Q1?"] + mock_query.first.assert_called_once() + mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() + + # Test 30: get_suggested_questions_after_answer - Disabled Error + @patch("services.message_service.WorkflowService") + @patch("services.message_service.AdvancedChatAppConfigManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_disabled_error( + self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory + ): + """Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value) + user = factory.create_end_user_mock() + mock_get_message.return_value = factory.create_message_mock() + + workflow = MagicMock() + mock_workflow_service.return_value.get_published_workflow.return_value = workflow + + app_config = MagicMock() + app_config.additional_features.suggested_questions_after_answer = False + mock_config_manager.get_app_config.return_value = app_config + + # Act & Assert + with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock() + ) diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py index 67ae2c9142..4449b442d6 100644 --- a/api/tests/unit_tests/services/test_messages_clean_service.py 
+++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -554,11 +554,9 @@ class TestMessagesCleanServiceFromDays: MessagesCleanService.from_days(policy=policy, days=-1) # Act - with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) - mock_datetime.datetime.now.return_value = fixed_now - mock_datetime.timedelta = datetime.timedelta - + mock_now.return_value = fixed_now service = MessagesCleanService.from_days(policy=policy, days=0) # Assert @@ -586,11 +584,9 @@ class TestMessagesCleanServiceFromDays: dry_run = True # Act - with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) - mock_datetime.datetime.now.return_value = fixed_now - mock_datetime.timedelta = datetime.timedelta - + mock_now.return_value = fixed_now service = MessagesCleanService.from_days( policy=policy, days=days, @@ -613,11 +609,9 @@ class TestMessagesCleanServiceFromDays: policy = BillingDisabledPolicy() # Act - with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) - mock_datetime.datetime.now.return_value = fixed_now - mock_datetime.timedelta = datetime.timedelta - + mock_now.return_value = fixed_now service = MessagesCleanService.from_days(policy=policy) # Assert diff --git a/api/tests/unit_tests/services/test_operation_service.py b/api/tests/unit_tests/services/test_operation_service.py new file mode 100644 index 0000000000..a4c69b23ac --- /dev/null +++ 
b/api/tests/unit_tests/services/test_operation_service.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from services.operation_service import OperationService + + +class TestOperationService: + """Test suite for OperationService""" + + # ===== Internal Method Tests ===== + + @patch("httpx.request") + def test_should_call_with_correct_parameters_when__send_request_invoked( + self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch + ): + """Test that _send_request calls httpx.request with the correct URL, headers, and data""" + # Arrange + monkeypatch.setattr(OperationService, "base_url", "https://billing.example") + monkeypatch.setattr(OperationService, "secret_key", "s3cr3t") + + mock_response = MagicMock() + mock_response.json.return_value = {"status": "success"} + mock_request.return_value = mock_response + + method = "POST" + endpoint = "/test_endpoint" + json_data = {"key": "value"} + + # Act + result = OperationService._send_request(method, endpoint, json=json_data) + + # Assert + assert result == {"status": "success"} + + # Verify call parameters + expected_url = "https://billing.example/test_endpoint" + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert args[0] == method + assert args[1] == expected_url + assert kwargs["json"] == json_data + assert kwargs["headers"]["Billing-Api-Secret-Key"] == "s3cr3t" + assert kwargs["headers"]["Content-Type"] == "application/json" + + @patch("httpx.request") + def test_should_propagate_httpx_error_when__send_request_raises( + self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch + ): + """Test that _send_request handles httpx raising an error""" + # Arrange + monkeypatch.setattr(OperationService, "base_url", "https://billing.example") + mock_request.side_effect = httpx.RequestError("network error") + + # Act & Assert + with pytest.raises(httpx.RequestError): + OperationService._send_request("POST", "/test") + + # ===== 
Public Method Tests ===== + + @pytest.mark.parametrize( + ("utm_info", "expected_params"), + [ + ( + { + "utm_source": "google", + "utm_medium": "cpc", + "utm_campaign": "spring_sale", + "utm_content": "ad_1", + "utm_term": "ai_agent", + }, + { + "tenant_id": "tenant-123", + "utm_source": "google", + "utm_medium": "cpc", + "utm_campaign": "spring_sale", + "utm_content": "ad_1", + "utm_term": "ai_agent", + }, + ), + ( + {}, # Empty utm_info + { + "tenant_id": "tenant-123", + "utm_source": "", + "utm_medium": "", + "utm_campaign": "", + "utm_content": "", + "utm_term": "", + }, + ), + ( + {"utm_source": "newsletter"}, # Partial utm_info + { + "tenant_id": "tenant-123", + "utm_source": "newsletter", + "utm_medium": "", + "utm_campaign": "", + "utm_content": "", + "utm_term": "", + }, + ), + ], + ) + @patch.object(OperationService, "_send_request") + def test_should_map_parameters_correctly_when_record_utm_called( + self, mock_send: MagicMock, utm_info: dict, expected_params: dict + ): + """Test that record_utm correctly maps utm_info to parameters and calls _send_request""" + # Arrange + tenant_id = "tenant-123" + mock_send.return_value = {"status": "recorded"} + + # Act + result = OperationService.record_utm(tenant_id, utm_info) + + # Assert + assert result == {"status": "recorded"} + mock_send.assert_called_once_with("POST", "/tenant_utms", params=expected_params) diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py index 7511fd6f0c..9537d207f0 100644 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py @@ -7,7 +7,7 @@ import pytest from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from 
core.tools.entities.tool_entities import ToolParameter, ToolProviderType from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -175,6 +175,137 @@ class TestMCPToolTransform: # The actual parameter conversion is handled by convert_mcp_schema_to_parameter # which should be tested separately + def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self): + """Nullable object schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "anyOf": [{"type": "object"}, {"type": "null"}], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self): + """Nullable oneOf object schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "oneOf": [{"type": "object"}, {"type": "null"}], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_handles_null_type(self): + """Schemas with only a null type should fall back to string.""" + schema = { + "type": "object", + "properties": { + "null_prop_str": {"type": "null"}, + "null_prop_list": {"type": ["null"]}, + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 2 + param_map = {parameter.name: parameter for parameter in 
result} + assert "null_prop_str" in param_map + assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING + assert "null_prop_list" in param_map + assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING + + def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self): + """Property-level allOf with multiple object items should still resolve to object.""" + schema = { + "type": "object", + "properties": { + "config": { + "allOf": [ + { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + }, + "required": ["enabled"], + }, + { + "type": "object", + "properties": { + "priority": {"type": "integer", "minimum": 1, "maximum": 10}, + }, + "required": ["priority"], + }, + ], + "description": "Config must match all schemas (allOf)", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "config" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["config"] + + def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self): + """Composed property schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "allOf": [ + {"type": "object"}, + {"properties": {"top_k": {"type": "integer"}}}, + ], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self): + """Self-referential composed schemas should stop resolving after the configured max depth.""" + recursive_property: dict[str, object] = 
{"description": "Recursive schema"} + recursive_property["anyOf"] = [recursive_property] + schema = { + "type": "object", + "properties": { + "recursive_config": recursive_property, + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "recursive_config" + assert result[0].type == ToolParameter.ToolParameterType.STRING + assert result[0].input_schema is None + def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full): """Test mcp_provider_to_user_provider with for_list=True.""" # Set tools data with null description diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 11b4663187..67e0a8efaf 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -10,14 +10,23 @@ This module tests the document indexing task functionality including: """ import uuid -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from core.indexing_runner import DocumentIsPausedError from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from models.dataset import Dataset, Document from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy +from tasks.document_indexing_task import ( + _document_indexing, + _document_indexing_with_tenant_queue, + document_indexing_task, + normal_document_indexing_task, + priority_document_indexing_task, +) # ============================================================================ # Fixtures @@ -56,6 +65,190 @@ def mock_redis(): return redis_client +# Additional fixtures required by tests in this module + + +@pytest.fixture +def mock_db_session(): + """Mock session_factory.create_session() to return a session whose queries 
use shared test data. + + Tests set session._shared_data = {"dataset": , "documents": [, ...]} + This fixture makes session.query(Dataset).first() return the shared dataset, + and session.query(Document).all()/first() return from the shared documents. + """ + with patch("tasks.document_indexing_task.session_factory") as mock_sf: + session = MagicMock() + session._shared_data = {"dataset": None, "documents": []} + + # Keep a pointer so repeated Document.first() calls iterate across provided docs + session._doc_first_idx = 0 + + def _query_side_effect(model): + q = MagicMock() + + # Capture filters passed via where(...) so first()/all() can honor them. + q._filters = {} + + def _extract_filters(*conds, **kw): + # Support both SQLAlchemy expressions (BinaryExpression) and kwargs + # We only need the simple fields used by production code: id, dataset_id, and id.in_(...) + for cond in conds: + left = getattr(cond, "left", None) + right = getattr(cond, "right", None) + key = None + if left is not None: + key = getattr(left, "key", None) or getattr(left, "name", None) + if not key: + continue + # Right side might be a BindParameter with .value, or a raw value/sequence + val = getattr(right, "value", right) + q._filters[key] = val + # Also accept kwargs (e.g., where(id=...)) just in case + for k, v in kw.items(): + q._filters[k] = v + + def _where_side_effect(*conds, **kw): + _extract_filters(*conds, **kw) + return q + + q.where.side_effect = _where_side_effect + + # Dataset queries + if model.__name__ == "Dataset": + + def _dataset_first(): + ds = session._shared_data.get("dataset") + if not ds: + return None + if "id" in q._filters: + val = q._filters["id"] + if isinstance(val, (list, tuple, set)): + return ds if ds.id in val else None + return ds if ds.id == val else None + return ds + + def _dataset_all(): + ds = session._shared_data.get("dataset") + if not ds: + return [] + first = _dataset_first() + return [first] if first else [] + + q.first.side_effect = 
_dataset_first + q.all.side_effect = _dataset_all + return q + + # Document queries + if model.__name__ == "Document": + + def _apply_doc_filters(docs): + result = list(docs) + for key in ("id", "dataset_id"): + if key in q._filters: + val = q._filters[key] + if isinstance(val, (list, tuple, set)): + result = [d for d in result if getattr(d, key, None) in val] + else: + result = [d for d in result if getattr(d, key, None) == val] + return result + + def _docs_all(): + docs = session._shared_data.get("documents", []) + return _apply_doc_filters(docs) + + def _docs_first(): + docs = _docs_all() + return docs[0] if docs else None + + q.all.side_effect = _docs_all + q.first.side_effect = _docs_first + return q + + # Default fallback + q.first.return_value = None + q.all.return_value = [] + return q + + session.query.side_effect = _query_side_effect + + # Implement session.begin() context manager that commits on exit + session.commit = MagicMock() + bm = MagicMock() + bm.__enter__.return_value = session + + def _bm_exit_side_effect(*args, **kwargs): + session.commit() + + bm.__exit__.side_effect = _bm_exit_side_effect + session.begin.return_value = bm + + # Context manager behavior for create_session(): ensure close() is called on exit + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + yield session + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_documents(document_ids, dataset_id): + """Create mock Document objects.""" + documents = [] + for doc_id in 
document_ids: + doc = Mock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + # optional attribute used in some code paths + doc.doc_form = "text_model" + documents.append(doc) + return documents + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner for document_indexing_task module.""" + with patch("tasks.document_indexing_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock() + mock_runner_class.return_value = mock_runner + yield mock_runner + + +@pytest.fixture +def mock_feature_service(): + """Mock FeatureService for document_indexing_task module.""" + with patch("tasks.document_indexing_task.FeatureService") as mock_service: + mock_features = Mock() + mock_features.billing = Mock() + mock_features.billing.enabled = False + mock_features.vector_space = Mock() + mock_features.vector_space.size = 0 + mock_features.vector_space.limit = 1000 + mock_service.get_features.return_value = mock_features + yield mock_service + + # ============================================================================ # Test Task Enqueuing # ============================================================================ @@ -166,6 +359,492 @@ class TestTaskEnqueuing: assert mock_redis.lpush.called mock_task.delay.assert_not_called() + def test_legacy_document_indexing_task_still_works( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_documents, mock_indexing_runner + ): + """ + Test that the legacy document_indexing_task function still works. + + This ensures backward compatibility for existing code that may still + use the deprecated function. 
+ """ + # Arrange + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + document_indexing_task(dataset_id, document_ids) + + # Assert + mock_indexing_runner.run.assert_called_once() + + +# ============================================================================ +# Test Batch Processing +# ============================================================================ + + +class TestBatchProcessing: + """Test cases for batch processing of multiple documents.""" + + def test_batch_processing_multiple_documents( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test batch processing of multiple documents. + + All documents in the batch should be processed together and their + status should be updated to 'parsing'. 
+ """ + # Arrange - Create actual document objects that can be modified + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should be set to 'parsing' status + for doc in mock_documents: + assert doc.indexing_status == "parsing" + assert doc.processing_started_at is not None + + # IndexingRunner should be called with all documents + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == len(document_ids) + + def test_batch_processing_with_limit_check(self, dataset_id, mock_db_session, mock_dataset, mock_feature_service): + """ + Test batch processing respects upload limits. + + When the number of documents exceeds the batch upload limit, + an error should be raised and all documents should be marked as error. 
+ """ + # Arrange + batch_limit = 10 + document_ids = [str(uuid.uuid4()) for _ in range(batch_limit + 1)] + + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 1000 + mock_feature_service.get_features.return_value.vector_space.size = 0 + + with patch("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", str(batch_limit)): + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should have error status + for doc in mock_documents: + assert doc.indexing_status == "error" + assert doc.error is not None + assert "batch upload limit" in doc.error + + def test_batch_processing_sandbox_plan_single_document_only( + self, dataset_id, mock_db_session, mock_dataset, mock_feature_service + ): + """ + Test that sandbox plan only allows single document upload. + + Sandbox plan should reject batch uploads (more than 1 document). 
+ """ + # Arrange + document_ids = [str(uuid.uuid4()) for _ in range(2)] + + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX + mock_feature_service.get_features.return_value.vector_space.limit = 1000 + mock_feature_service.get_features.return_value.vector_space.size = 0 + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should have error status + for doc in mock_documents: + assert doc.indexing_status == "error" + assert "does not support batch upload" in doc.error + + def test_batch_processing_empty_document_list( + self, dataset_id, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test batch processing with empty document list. + + Should handle empty list gracefully without errors. 
+ """ + # Arrange + document_ids = [] + + # Set shared mock data with empty documents list + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [] + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - IndexingRunner should still be called with empty list + mock_indexing_runner.run.assert_called_once_with([]) + + +# ============================================================================ +# Test Progress Tracking +# ============================================================================ + + +class TestProgressTracking: + """Test cases for progress tracking through task lifecycle.""" + + def test_document_status_progression( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test document status progresses correctly through lifecycle. + + Documents should transition from 'waiting' -> 'parsing' -> processed. 
+ """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - Status should be 'parsing' + for doc in mock_documents: + assert doc.indexing_status == "parsing" + assert doc.processing_started_at is not None + + # Verify commit was called to persist status + assert mock_db_session.commit.called + + def test_processing_started_timestamp_set( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that processing_started_at timestamp is set correctly. + + When documents start processing, the timestamp should be recorded. 
+ """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert + for doc in mock_documents: + assert doc.processing_started_at is not None + + def test_tenant_queue_processes_next_task_after_completion( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that tenant queue processes next waiting task after completion. + + After a task completes, the system should check for waiting tasks + and process the next one. 
+ """ + # Arrange + next_task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["next_doc_id"]} + + # Simulate next task in queue + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=next_task_data) + mock_redis.rpop.return_value = wrapper.serialize() + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Next task should be enqueued + mock_task.apply_async.assert_called() + # Task key should be set for next task + assert mock_redis.setex.called + + def test_tenant_queue_clears_flag_when_no_more_tasks( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that tenant queue clears flag when no more tasks are waiting. + + When there are no more tasks in the queue, the task key should be deleted. 
+ """ + # Arrange + mock_redis.rpop.return_value = None # No more tasks + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Task key should be deleted + assert mock_redis.delete.called + + +# ============================================================================ +# Test Error Handling and Retries +# ============================================================================ + + +class TestErrorHandling: + """Test cases for error handling and retry mechanisms.""" + + def test_error_handling_sets_document_error_status( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_feature_service + ): + """ + Test that errors during validation set document error status. + + When validation fails (e.g., limit exceeded), documents should be + marked with error status and error message. 
+ """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Set up to trigger vector space limit error + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 100 + mock_feature_service.get_features.return_value.vector_space.size = 100 # At limit + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert + for doc in mock_documents: + assert doc.indexing_status == "error" + assert doc.error is not None + assert "over the limit" in doc.error + assert doc.stopped_at is not None + + def test_error_handling_during_indexing_runner( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_documents, mock_indexing_runner + ): + """ + Test error handling when IndexingRunner raises an exception. + + Errors during indexing should be caught and logged, but not crash the task. 
+ """ + # Arrange + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Make IndexingRunner raise an exception + mock_indexing_runner.run.side_effect = Exception("Indexing failed") + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act - Should not raise exception + _document_indexing(dataset_id, document_ids) + + # Assert - Session should be closed even after error + assert mock_db_session.close.called + + def test_document_paused_error_handling( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_documents, mock_indexing_runner + ): + """ + Test handling of DocumentIsPausedError. + + When a document is paused, the error should be caught and logged + but not treated as a failure. + """ + # Arrange + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Make IndexingRunner raise DocumentIsPausedError + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act - Should not raise exception + _document_indexing(dataset_id, document_ids) + + # Assert - Session should be closed + assert mock_db_session.close.called + + def test_dataset_not_found_error_handling(self, dataset_id, document_ids, mock_db_session): + """ + Test handling when dataset is not found. + + If the dataset doesn't exist, the task should exit gracefully. 
+ """ + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - Session should be closed + assert mock_db_session.close.called + + def test_tenant_queue_error_handling_still_processes_next_task( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that errors don't prevent processing next task in tenant queue. + + Even if the current task fails, the next task should still be processed. + """ + # Arrange + next_task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["next_doc_id"]} + + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=next_task_data) + # Set up rpop to return task once for concurrency check + mock_redis.rpop.side_effect = [wrapper.serialize(), None] + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + # Make _document_indexing raise an error + with patch("tasks.document_indexing_task._document_indexing") as mock_indexing: + mock_indexing.side_effect = Exception("Processing failed") + + # Patch logger to avoid format string issue in actual code + with patch("tasks.document_indexing_task.logger"): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Next task should still be enqueued despite error + mock_task.apply_async.assert_called() + + def test_concurrent_task_limit_respected( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset + ): + """ + Test that tenant isolated task concurrency limit is respected. + + Should pull only TENANT_ISOLATED_TASK_CONCURRENCY tasks at a time. 
+ """ + # Arrange + concurrency_limit = 2 + + # Create multiple tasks in queue + tasks = [] + for i in range(5): + task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": [f"doc_{i}"]} + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks one by one + mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None] + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Should enqueue exactly concurrency_limit tasks + assert mock_task.apply_async.call_count == concurrency_limit + # ============================================================================ # Test Task Cancellation @@ -198,6 +877,407 @@ class TestTaskCancellation: assert tenant_2 in queue_2._queue +# ============================================================================ +# Integration Tests +# ============================================================================ + + +class TestAdvancedScenarios: + """Advanced test scenarios for edge cases and complex workflows.""" + + def test_multiple_documents_with_mixed_success_and_failure( + self, dataset_id, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test handling of mixed success and failure scenarios in batch processing. + + When processing multiple documents, some may succeed while others fail. + This tests that the system handles partial failures gracefully. 
+ + Scenario: + - Process 3 documents in a batch + - First document succeeds + - Second document is not found (skipped) + - Third document succeeds + + Expected behavior: + - Only found documents are processed + - Missing documents are skipped without crashing + - IndexingRunner receives only valid documents + """ + # Arrange - Create document IDs with one missing + document_ids = [str(uuid.uuid4()) for _ in range(3)] + + # Create only 2 documents (simulate one missing) + # The new code uses .all() which will only return existing documents + mock_documents = [] + for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data - .all() will only return existing documents + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - Only 2 documents should be processed (missing one skipped) + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == 2 # Only found documents + + def test_tenant_queue_with_multiple_concurrent_tasks( + self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset + ): + """ + Test concurrent task processing with tenant isolation. + + This tests the scenario where multiple tasks are queued for the same tenant + and need to be processed respecting the concurrency limit. 
+ + Scenario: + - 5 tasks are waiting in the queue + - Concurrency limit is 2 + - After current task completes, pull and enqueue next 2 tasks + + Expected behavior: + - Exactly 2 tasks are pulled from queue (respecting concurrency) + - Each task is enqueued with correct parameters + - Task waiting time is set for each new task + """ + # Arrange + concurrency_limit = 2 + document_ids = [str(uuid.uuid4())] + + # Create multiple waiting tasks + waiting_tasks = [] + for i in range(5): + task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": [f"doc_{i}"]} + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + waiting_tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks up to concurrency limit + mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert + # Should enqueue exactly concurrency_limit tasks + assert mock_task.apply_async.call_count == concurrency_limit + + # Verify task waiting time was set for each task + assert mock_redis.setex.call_count >= concurrency_limit + + def test_vector_space_limit_edge_case_at_exact_limit( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_feature_service + ): + """ + Test vector space limit validation at exact boundary. + + Edge case: When vector space is exactly at the limit (not over), + the upload should still be rejected. 
+ + Scenario: + - Vector space limit: 100 + - Current size: 100 (exactly at limit) + - Try to upload 3 documents + + Expected behavior: + - Upload is rejected with appropriate error message + - All documents are marked with error status + """ + # Arrange + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Set vector space exactly at limit + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 100 + mock_feature_service.get_features.return_value.vector_space.size = 100 # Exactly at limit + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should have error status + for doc in mock_documents: + assert doc.indexing_status == "error" + assert "over the limit" in doc.error + + def test_task_queue_fifo_ordering(self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset): + """ + Test that tasks are processed in FIFO (First-In-First-Out) order. + + The tenant isolated queue should maintain task order, ensuring + that tasks are processed in the sequence they were added. 
+ + Scenario: + - Task A added first + - Task B added second + - Task C added third + - When pulling tasks, should get A, then B, then C + + Expected behavior: + - Tasks are retrieved in the order they were added + - FIFO ordering is maintained throughout processing + """ + # Arrange + document_ids = [str(uuid.uuid4())] + + # Create tasks with identifiable document IDs to track order + task_order = ["task_A", "task_B", "task_C"] + tasks = [] + for task_name in task_order: + task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": [task_name]} + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks in FIFO order + mock_redis.rpop.side_effect = tasks + [None] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Verify tasks were enqueued in correct order + assert mock_task.apply_async.call_count == 3 + + # Check that document_ids in calls match expected order + for i, call_obj in enumerate(mock_task.apply_async.call_args_list): + called_doc_ids = call_obj[1]["kwargs"]["document_ids"] + assert called_doc_ids == [task_order[i]] + + def test_empty_queue_after_task_completion_cleans_up( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset + ): + """ + Test cleanup behavior when queue becomes empty after task completion. + + After processing the last task in the queue, the system should: + 1. Detect that no more tasks are waiting + 2. Delete the task key to indicate tenant is idle + 3. 
Allow new tasks to start fresh processing + + Scenario: + - Process a task + - Check queue for next tasks + - Queue is empty + - Task key should be deleted + + Expected behavior: + - Task key is deleted when queue is empty + - Tenant is marked as idle (no active tasks) + """ + # Arrange + mock_redis.rpop.return_value = None # Empty queue + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert + # Verify delete was called to clean up task key + mock_redis.delete.assert_called_once() + + # Verify the correct key was deleted (contains tenant_id and "document_indexing") + delete_call_args = mock_redis.delete.call_args[0][0] + assert tenant_id in delete_call_args + assert "document_indexing" in delete_call_args + + def test_billing_disabled_skips_limit_checks( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service + ): + """ + Test that billing limit checks are skipped when billing is disabled. + + For self-hosted or enterprise deployments where billing is disabled, + the system should not enforce vector space or batch upload limits. 
+ + Scenario: + - Billing is disabled + - Upload 100 documents (would normally exceed limits) + - No limit checks should be performed + + Expected behavior: + - Documents are processed without limit validation + - No errors related to limits + - All documents proceed to indexing + """ + # Arrange - Create many documents + large_batch_ids = [str(uuid.uuid4()) for _ in range(100)] + + mock_documents = [] + for doc_id in large_batch_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Billing disabled - limits should not be checked + mock_feature_service.get_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, large_batch_ids) + + # Assert + # All documents should be set to parsing (no limit errors) + for doc in mock_documents: + assert doc.indexing_status == "parsing" + + # IndexingRunner should be called with all documents + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == 100 + + +class TestIntegration: + """Integration tests for complete task workflows.""" + + def test_complete_workflow_normal_task( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test complete workflow for normal document indexing task. + + This tests the full flow from task receipt to completion. 
+ """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set up rpop to return None for concurrency check (no more tasks) + mock_redis.rpop.side_effect = [None] + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + normal_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + # Documents should be processed + mock_indexing_runner.run.assert_called_once() + # Session should be closed + assert mock_db_session.close.called + # Task key should be deleted (no more tasks) + assert mock_redis.delete.called + + def test_complete_workflow_priority_task( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test complete workflow for priority document indexing task. + + Priority tasks should follow the same flow as normal tasks. 
+ """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set up rpop to return None for concurrency check (no more tasks) + mock_redis.rpop.side_effect = [None] + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + priority_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_indexing_runner.run.assert_called_once() + assert mock_db_session.close.called + assert mock_redis.delete.called + + def test_queue_chain_processing( + self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that multiple tasks in queue are processed in sequence. + + When tasks are queued, they should be processed one after another. 
+ """ + # Arrange + task_1_docs = [str(uuid.uuid4())] + task_2_docs = [str(uuid.uuid4())] + + task_2_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": task_2_docs} + + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_2_data) + + # First call returns task 2, second call returns None + mock_redis.rpop.side_effect = [wrapper.serialize(), None] + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act - Process first task + _document_indexing_with_tenant_queue(tenant_id, dataset_id, task_1_docs, mock_task) + + # Assert - Second task should be enqueued + assert mock_task.apply_async.called + call_args = mock_task.apply_async.call_args + assert call_args[1]["kwargs"]["document_ids"] == task_2_docs + + # ============================================================================ # Additional Edge Case Tests # ============================================================================ @@ -249,6 +1329,107 @@ class TestEdgeCases: class TestPerformanceScenarios: """Test performance-related scenarios and optimizations.""" + def test_large_document_batch_processing( + self, dataset_id, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service + ): + """ + Test processing a large batch of documents at batch limit. + + When processing the maximum allowed batch size, the system + should handle it efficiently without errors. 
+ + Scenario: + - Process exactly batch_upload_limit documents (e.g., 50) + - All documents are valid + - Billing is enabled + + Expected behavior: + - All documents are processed successfully + - No timeout or memory issues + - Batch limit is not exceeded + """ + # Arrange + batch_limit = 50 + document_ids = [str(uuid.uuid4()) for _ in range(batch_limit)] + + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Configure billing with sufficient limits + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 10000 + mock_feature_service.get_features.return_value.vector_space.size = 0 + + with patch("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", str(batch_limit)): + # Act + _document_indexing(dataset_id, document_ids) + + # Assert + for doc in mock_documents: + assert doc.indexing_status == "parsing" + + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == batch_limit + + def test_tenant_queue_handles_burst_traffic(self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset): + """ + Test tenant queue handling burst traffic scenarios. + + When many tasks arrive in a burst for the same tenant, + the queue should handle them efficiently without dropping tasks. 
+ + Scenario: + - 20 tasks arrive rapidly + - Concurrency limit is 3 + - Tasks should be queued and processed in batches + + Expected behavior: + - First 3 tasks are processed immediately + - Remaining tasks wait in queue + - No tasks are lost + """ + # Arrange + num_tasks = 20 + concurrency_limit = 3 + document_ids = [str(uuid.uuid4())] + + # Create waiting tasks + waiting_tasks = [] + for i in range(num_tasks): + task_data = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": [f"doc_{i}"], + } + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + waiting_tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks up to concurrency limit + mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Should process exactly concurrency_limit tasks + assert mock_task.apply_async.call_count == concurrency_limit + def test_multiple_tenants_isolated_processing(self, mock_redis): """ Test that multiple tenants process tasks in isolation. diff --git a/api/tests/unit_tests/tasks/test_summary_queue_isolation.py b/api/tests/unit_tests/tasks/test_summary_queue_isolation.py new file mode 100644 index 0000000000..f6632e0a8a --- /dev/null +++ b/api/tests/unit_tests/tasks/test_summary_queue_isolation.py @@ -0,0 +1,40 @@ +""" +Unit tests for summary index task queue isolation. + +These tasks must NOT run on the shared 'dataset' queue because they invoke LLMs +for each document segment and can occupy all worker slots for hours, blocking +document indexing tasks. 
+""" + +import pytest + +from tasks.generate_summary_index_task import generate_summary_index_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task + +SUMMARY_QUEUE = "dataset_summary" +INDEXING_QUEUE = "dataset" + + +def _task_queue(task) -> str | None: + # Celery's @shared_task(queue=...) stores the routing key on the task instance + # at runtime, but type stubs don't declare it; use getattr to stay type-clean. + return getattr(task, "queue", None) + + +@pytest.mark.parametrize( + ("task", "task_name"), + [ + (generate_summary_index_task, "generate_summary_index_task"), + (regenerate_summary_index_task, "regenerate_summary_index_task"), + ], +) +def test_summary_task_uses_dedicated_queue(task, task_name): + """Summary tasks must use the dataset_summary queue, not the shared dataset queue. + + Summary generation is LLM-heavy and will block document indexing if placed + on the shared queue. + """ + assert _task_queue(task) == SUMMARY_QUEUE, ( + f"{task_name} must run on '{SUMMARY_QUEUE}' queue (not '{INDEXING_QUEUE}'). " + "Summary generation is LLM-heavy and will block document indexing if placed on the shared queue." 
+ ) diff --git a/api/uv.lock b/api/uv.lock index 5a9ac096dc..6b4fea62a5 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -505,14 +505,14 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.31.7" +version = "1.38.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/ba/ed69e8df732a09c8ca469f592c8e08707fe29149735b834c276d94d4a3da/basedpyright-1.31.7.tar.gz", hash = "sha256:394f334c742a19bcc5905b2455c9f5858182866b7679a6f057a70b44b049bceb", size = 22710948, upload-time = "2025-10-11T05:12:48.3Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/a3/20aa7c4e83f2f614e0036300f3c352775dede0655c66814da16c37b661a9/basedpyright-1.38.2.tar.gz", hash = "sha256:b433b2b8ba745ed7520cdc79a29a03682f3fb00346d272ece5944e9e5e5daa92", size = 25277019, upload-time = "2026-02-26T11:18:43.594Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, + { url = "https://files.pythonhosted.org/packages/ac/12/736cab83626fea3fe65cdafb3ef3d2ee9480c56723f2fd33921537289a5e/basedpyright-1.38.2-py3-none-any.whl", hash = "sha256:153481d37fd19f9e3adedc8629d1d071b10c5f5e49321fb026b74444b7c70e24", size = 12312475, upload-time = "2026-02-26T11:18:40.373Z" }, ] [[package]] @@ -1606,7 +1606,7 @@ requires-dist = [ { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "litellm", specifier = "==1.77.1" }, - { name = "markdown", specifier = "~=3.5.1" }, + { name = "markdown", specifier = "~=3.8.1" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, { name = "openpyxl", specifier = "~=3.1.5" }, @@ -1660,7 +1660,7 @@ requires-dist = [ 
[package.metadata.requires-dev] dev = [ - { name = "basedpyright", specifier = "~=1.31.0" }, + { name = "basedpyright", specifier = "~=1.38.2" }, { name = "boto3-stubs", specifier = ">=1.38.20" }, { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = "~=7.2.4" }, @@ -1669,9 +1669,9 @@ dev = [ { name = "hypothesis", specifier = ">=6.131.15" }, { name = "import-linter", specifier = ">=2.3" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, - { name = "mypy", specifier = "~=1.17.1" }, + { name = "mypy", specifier = "~=1.19.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, - { name = "pyrefly", specifier = ">=0.54.0" }, + { name = "pyrefly", specifier = ">=0.55.0" }, { name = "pytest", specifier = "~=8.3.2" }, { name = "pytest-benchmark", specifier = "~=4.0.0" }, { name = "pytest-cov", specifier = "~=4.1.0" }, @@ -3267,6 +3267,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/f0/63b06b99b730b9954f8709f6f7d9b8d076fa0a973e472efe278089bde42b/langsmith-0.1.147-py3-none-any.whl", hash = "sha256:7166fc23b965ccf839d64945a78e9f1157757add228b086141eb03a60d699a15", size = 311812, upload-time = "2024-11-27T17:32:39.569Z" }, ] +[[package]] +name = "librt" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/9c/b4b0c54d84da4a94b37bd44151e46d5e583c9534c7e02250b961b1b6d8a8/librt-0.8.1.tar.gz", hash = "sha256:be46a14693955b3bd96014ccbdb8339ee8c9346fbe11c1b78901b55125f14c73", size = 177471, upload-time = "2026-02-17T16:13:06.101Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/01/0e748af5e4fee180cf7cd12bd12b0513ad23b045dccb2a83191bde82d168/librt-0.8.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:681dc2451d6d846794a828c16c22dc452d924e9f700a485b7ecb887a30aad1fd", size = 65315, upload-time = "2026-02-17T16:11:25.152Z" }, + { url = 
"https://files.pythonhosted.org/packages/9d/4d/7184806efda571887c798d573ca4134c80ac8642dcdd32f12c31b939c595/librt-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3b4350b13cc0e6f5bec8fa7caf29a8fb8cdc051a3bae45cfbfd7ce64f009965", size = 68021, upload-time = "2026-02-17T16:11:26.129Z" }, + { url = "https://files.pythonhosted.org/packages/ae/88/c3c52d2a5d5101f28d3dc89298444626e7874aa904eed498464c2af17627/librt-0.8.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ac1e7817fd0ed3d14fd7c5df91daed84c48e4c2a11ee99c0547f9f62fdae13da", size = 194500, upload-time = "2026-02-17T16:11:27.177Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5d/6fb0a25b6a8906e85b2c3b87bee1d6ed31510be7605b06772f9374ca5cb3/librt-0.8.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:747328be0c5b7075cde86a0e09d7a9196029800ba75a1689332348e998fb85c0", size = 205622, upload-time = "2026-02-17T16:11:28.242Z" }, + { url = "https://files.pythonhosted.org/packages/b2/a6/8006ae81227105476a45691f5831499e4d936b1c049b0c1feb17c11b02d1/librt-0.8.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0af2bd2bc204fa27f3d6711d0f360e6b8c684a035206257a81673ab924aa11e", size = 218304, upload-time = "2026-02-17T16:11:29.344Z" }, + { url = "https://files.pythonhosted.org/packages/ee/19/60e07886ad16670aae57ef44dada41912c90906a6fe9f2b9abac21374748/librt-0.8.1-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d480de377f5b687b6b1bc0c0407426da556e2a757633cc7e4d2e1a057aa688f3", size = 211493, upload-time = "2026-02-17T16:11:30.445Z" }, + { url = "https://files.pythonhosted.org/packages/9c/cf/f666c89d0e861d05600438213feeb818c7514d3315bae3648b1fc145d2b6/librt-0.8.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d0ee06b5b5291f609ddb37b9750985b27bc567791bc87c76a569b3feed8481ac", size = 219129, upload-time = "2026-02-17T16:11:32.021Z" }, + 
{ url = "https://files.pythonhosted.org/packages/8f/ef/f1bea01e40b4a879364c031476c82a0dc69ce068daad67ab96302fed2d45/librt-0.8.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9e2c6f77b9ad48ce5603b83b7da9ee3e36b3ab425353f695cba13200c5d96596", size = 213113, upload-time = "2026-02-17T16:11:33.192Z" }, + { url = "https://files.pythonhosted.org/packages/9b/80/cdab544370cc6bc1b72ea369525f547a59e6938ef6863a11ab3cd24759af/librt-0.8.1-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:439352ba9373f11cb8e1933da194dcc6206daf779ff8df0ed69c5e39113e6a99", size = 212269, upload-time = "2026-02-17T16:11:34.373Z" }, + { url = "https://files.pythonhosted.org/packages/9d/9c/48d6ed8dac595654f15eceab2035131c136d1ae9a1e3548e777bb6dbb95d/librt-0.8.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:82210adabbc331dbb65d7868b105185464ef13f56f7f76688565ad79f648b0fe", size = 234673, upload-time = "2026-02-17T16:11:36.063Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/35b68b1db517f27a01be4467593292eb5315def8900afad29fabf56304ba/librt-0.8.1-cp311-cp311-win32.whl", hash = "sha256:52c224e14614b750c0a6d97368e16804a98c684657c7518752c356834fff83bb", size = 54597, upload-time = "2026-02-17T16:11:37.544Z" }, + { url = "https://files.pythonhosted.org/packages/71/02/796fe8f02822235966693f257bf2c79f40e11337337a657a8cfebba5febc/librt-0.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:c00e5c884f528c9932d278d5c9cbbea38a6b81eb62c02e06ae53751a83a4d52b", size = 61733, upload-time = "2026-02-17T16:11:38.691Z" }, + { url = "https://files.pythonhosted.org/packages/28/ad/232e13d61f879a42a4e7117d65e4984bb28371a34bb6fb9ca54ec2c8f54e/librt-0.8.1-cp311-cp311-win_arm64.whl", hash = "sha256:f7cdf7f26c2286ffb02e46d7bac56c94655540b26347673bea15fa52a6af17e9", size = 52273, upload-time = "2026-02-17T16:11:40.308Z" }, + { url = "https://files.pythonhosted.org/packages/95/21/d39b0a87ac52fc98f621fb6f8060efb017a767ebbbac2f99fbcbc9ddc0d7/librt-0.8.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:a28f2612ab566b17f3698b0da021ff9960610301607c9a5e8eaca62f5e1c350a", size = 66516, upload-time = "2026-02-17T16:11:41.604Z" }, + { url = "https://files.pythonhosted.org/packages/69/f1/46375e71441c43e8ae335905e069f1c54febee63a146278bcee8782c84fd/librt-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:60a78b694c9aee2a0f1aaeaa7d101cf713e92e8423a941d2897f4fa37908dab9", size = 68634, upload-time = "2026-02-17T16:11:43.268Z" }, + { url = "https://files.pythonhosted.org/packages/0a/33/c510de7f93bf1fa19e13423a606d8189a02624a800710f6e6a0a0f0784b3/librt-0.8.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:758509ea3f1eba2a57558e7e98f4659d0ea7670bff49673b0dde18a3c7e6c0eb", size = 198941, upload-time = "2026-02-17T16:11:44.28Z" }, + { url = "https://files.pythonhosted.org/packages/dd/36/e725903416409a533d92398e88ce665476f275081d0d7d42f9c4951999e5/librt-0.8.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:039b9f2c506bd0ab0f8725aa5ba339c6f0cd19d3b514b50d134789809c24285d", size = 209991, upload-time = "2026-02-17T16:11:45.462Z" }, + { url = "https://files.pythonhosted.org/packages/30/7a/8d908a152e1875c9f8eac96c97a480df425e657cdb47854b9efaa4998889/librt-0.8.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bb54f1205a3a6ab41a6fd71dfcdcbd278670d3a90ca502a30d9da583105b6f7", size = 224476, upload-time = "2026-02-17T16:11:46.542Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b8/a22c34f2c485b8903a06f3fe3315341fe6876ef3599792344669db98fcff/librt-0.8.1-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:05bd41cdee35b0c59c259f870f6da532a2c5ca57db95b5f23689fcb5c9e42440", size = 217518, upload-time = "2026-02-17T16:11:47.746Z" }, + { url = "https://files.pythonhosted.org/packages/79/6f/5c6fea00357e4f82ba44f81dbfb027921f1ab10e320d4a64e1c408d035d9/librt-0.8.1-cp312-cp312-musllinux_1_2_aarch64.whl", 
hash = "sha256:adfab487facf03f0d0857b8710cf82d0704a309d8ffc33b03d9302b4c64e91a9", size = 225116, upload-time = "2026-02-17T16:11:49.298Z" }, + { url = "https://files.pythonhosted.org/packages/f2/a0/95ced4e7b1267fe1e2720a111685bcddf0e781f7e9e0ce59d751c44dcfe5/librt-0.8.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:153188fe98a72f206042be10a2c6026139852805215ed9539186312d50a8e972", size = 217751, upload-time = "2026-02-17T16:11:50.49Z" }, + { url = "https://files.pythonhosted.org/packages/93/c2/0517281cb4d4101c27ab59472924e67f55e375bc46bedae94ac6dc6e1902/librt-0.8.1-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:dd3c41254ee98604b08bd5b3af5bf0a89740d4ee0711de95b65166bf44091921", size = 218378, upload-time = "2026-02-17T16:11:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/43/e8/37b3ac108e8976888e559a7b227d0ceac03c384cfd3e7a1c2ee248dbae79/librt-0.8.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e0d138c7ae532908cbb342162b2611dbd4d90c941cd25ab82084aaf71d2c0bd0", size = 241199, upload-time = "2026-02-17T16:11:53.561Z" }, + { url = "https://files.pythonhosted.org/packages/4b/5b/35812d041c53967fedf551a39399271bbe4257e681236a2cf1a69c8e7fa1/librt-0.8.1-cp312-cp312-win32.whl", hash = "sha256:43353b943613c5d9c49a25aaffdba46f888ec354e71e3529a00cca3f04d66a7a", size = 54917, upload-time = "2026-02-17T16:11:54.758Z" }, + { url = "https://files.pythonhosted.org/packages/de/d1/fa5d5331b862b9775aaf2a100f5ef86854e5d4407f71bddf102f4421e034/librt-0.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:ff8baf1f8d3f4b6b7257fcb75a501f2a5499d0dda57645baa09d4d0d34b19444", size = 62017, upload-time = "2026-02-17T16:11:55.748Z" }, + { url = "https://files.pythonhosted.org/packages/c7/7c/c614252f9acda59b01a66e2ddfd243ed1c7e1deab0293332dfbccf862808/librt-0.8.1-cp312-cp312-win_arm64.whl", hash = "sha256:0f2ae3725904f7377e11cc37722d5d401e8b3d5851fb9273d7f4fe04f6b3d37d", size = 52441, upload-time = "2026-02-17T16:11:56.801Z" }, +] + [[package]] name = "litellm" 
version = "1.77.1" @@ -3403,11 +3437,11 @@ wheels = [ [[package]] name = "markdown" -version = "3.5.2" +version = "3.8.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/28/c5441a6642681d92de56063fa7984df56f783d3f1eba518dc3e7a253b606/Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8", size = 349398, upload-time = "2024-01-10T15:19:38.261Z" } +sdist = { url = "https://files.pythonhosted.org/packages/db/7c/0738e5ff0adccd0b4e02c66d0446c03a3c557e02bb49b7c263d7ab56c57d/markdown-3.8.1.tar.gz", hash = "sha256:a2e2f01cead4828ee74ecca9623045f62216aef2212a7685d6eb9163f590b8c1", size = 361280, upload-time = "2025-06-18T14:50:49.618Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/f4/f0031854de10a0bc7821ef9fca0b92ca0d7aa6fbfbf504c5473ba825e49c/Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd", size = 103870, upload-time = "2024-01-10T15:19:36.071Z" }, + { url = "https://files.pythonhosted.org/packages/50/34/3d1ff0cb4843a33817d06800e9383a2b2a2df4d508e37f53a40e829905d9/markdown-3.8.1-py3-none-any.whl", hash = "sha256:46cc0c0f1e5211ab2e9d453582f0b28a1bfaf058a9f7d5c50386b99b588d8811", size = 106642, upload-time = "2025-06-18T14:50:48.52Z" }, ] [[package]] @@ -3653,28 +3687,29 @@ wheels = [ [[package]] name = "mypy" -version = "1.17.1" +version = "1.19.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, { name = "mypy-extensions" }, { name = "pathspec" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404, upload-time = "2025-12-15T05:03:48.42Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" }, - { url = "https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" }, - { url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" }, - { url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" }, - { url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" }, - { url = 
"https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" }, - { url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" }, - { url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" }, - { url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" }, - { url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" }, - { url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" }, - { url = 
"https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = "2025-07-31T07:54:12.842Z" }, - { url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" }, + { url = "https://files.pythonhosted.org/packages/ef/47/6b3ebabd5474d9cdc170d1342fbf9dddc1b0ec13ec90bf9004ee6f391c31/mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288", size = 13028539, upload-time = "2025-12-15T05:03:44.129Z" }, + { url = "https://files.pythonhosted.org/packages/5c/a6/ac7c7a88a3c9c54334f53a941b765e6ec6c4ebd65d3fe8cdcfbe0d0fd7db/mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab", size = 12083163, upload-time = "2025-12-15T05:03:37.679Z" }, + { url = "https://files.pythonhosted.org/packages/67/af/3afa9cf880aa4a2c803798ac24f1d11ef72a0c8079689fac5cfd815e2830/mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6", size = 12687629, upload-time = "2025-12-15T05:02:31.526Z" }, + { url = "https://files.pythonhosted.org/packages/2d/46/20f8a7114a56484ab268b0ab372461cb3a8f7deed31ea96b83a4e4cfcfca/mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331", size = 13436933, upload-time = "2025-12-15T05:03:15.606Z" }, + { url = 
"https://files.pythonhosted.org/packages/5b/f8/33b291ea85050a21f15da910002460f1f445f8007adb29230f0adea279cb/mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925", size = 13661754, upload-time = "2025-12-15T05:02:26.731Z" }, + { url = "https://files.pythonhosted.org/packages/fd/a3/47cbd4e85bec4335a9cd80cf67dbc02be21b5d4c9c23ad6b95d6c5196bac/mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042", size = 10055772, upload-time = "2025-12-15T05:03:26.179Z" }, + { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053, upload-time = "2025-12-15T05:03:46.622Z" }, + { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134, upload-time = "2025-12-15T05:03:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616, upload-time = "2025-12-15T05:02:44.725Z" }, + { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 13620847, upload-time = "2025-12-15T05:03:39.633Z" }, + { url = 
"https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976, upload-time = "2025-12-15T05:03:08.786Z" }, + { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104, upload-time = "2025-12-15T05:03:10.834Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239, upload-time = "2025-12-15T05:03:07.248Z" }, ] [[package]] @@ -5078,11 +5113,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.7.5" +version = "6.8.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, upload-time = "2026-03-02T09:05:19.722Z" }, + { url = 
"https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" }, ] [[package]] @@ -5140,18 +5175,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.54.0" +version = "0.55.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/81/44/c10b16a302fda90d0af1328f880b232761b510eab546616a7be2fdf35a57/pyrefly-0.54.0.tar.gz", hash = "sha256:c6663be64d492f0d2f2a411ada9f28a6792163d34133639378b7f3dd9a8dca94", size = 5098893, upload-time = "2026-02-23T15:44:35.111Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/99/8fdcdb4e55f0227fdd9f6abce36b619bab1ecb0662b83b66adc8cba3c788/pyrefly-0.54.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:58a3f092b6dc25ef79b2dc6c69a40f36784ca157c312bfc0baea463926a9db6d", size = 12223973, upload-time = "2026-02-23T15:44:14.278Z" }, - { url = "https://files.pythonhosted.org/packages/90/35/c2aaf87a76003ad27b286594d2e5178f811eaa15bfe3d98dba2b47d56dd1/pyrefly-0.54.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:615081414106dd95873bc39c3a4bed68754c6cc24a8177ac51d22f88f88d3eb3", size = 11785585, upload-time = "2026-02-23T15:44:17.468Z" }, - { url = "https://files.pythonhosted.org/packages/c4/4a/ced02691ed67e5a897714979196f08ad279ec7ec7f63c45e00a75a7f3c0e/pyrefly-0.54.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbcaf20f5fe585079079a95205c1f3cd4542d17228cdf1df560288880623b70", size = 33381977, upload-time = "2026-02-23T15:44:19.736Z" }, - { 
url = "https://files.pythonhosted.org/packages/0b/ce/72a117ed437c8f6950862181014b41e36f3c3997580e29b772b71e78d587/pyrefly-0.54.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d5da116c0d34acfbd66663addd3ca8aa78a636f6692a66e078126d3620a883", size = 35962821, upload-time = "2026-02-23T15:44:22.357Z" }, - { url = "https://files.pythonhosted.org/packages/85/de/89013f5ae0a35d2b6b01274a92a35ee91431ea001050edf0a16748d39875/pyrefly-0.54.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef3ac27f1a4baaf67aead64287d3163350844794aca6315ad1a9650b16ec26a", size = 38496689, upload-time = "2026-02-23T15:44:25.236Z" }, - { url = "https://files.pythonhosted.org/packages/9f/9a/33b097c7bf498b924742dca32dd5d9c6a3fa6c2b52b63a58eb9e1980ca89/pyrefly-0.54.0-py3-none-win32.whl", hash = "sha256:7d607d72200a8afbd2db10bfefb40160a7a5d709d207161c21649cedd5cfc09a", size = 11295268, upload-time = "2026-02-23T15:44:27.551Z" }, - { url = "https://files.pythonhosted.org/packages/d4/21/9263fd1144d2a3d7342b474f183f7785b3358a1565c864089b780110b933/pyrefly-0.54.0-py3-none-win_amd64.whl", hash = "sha256:fd416f04f89309385696f685bd5c9141011f18c8072f84d31ca20c748546e791", size = 12081810, upload-time = "2026-02-23T15:44:29.461Z" }, - { url = "https://files.pythonhosted.org/packages/ea/5b/fad062a196c064cbc8564de5b2f4d3cb6315f852e3b31e8a1ce74c69a1ea/pyrefly-0.54.0-py3-none-win_arm64.whl", hash = "sha256:f06ab371356c7b1925e0bffe193b738797e71e5dbbff7fb5a13f90ee7521211d", size = 11564930, upload-time = "2026-02-23T15:44:33.053Z" }, + { url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" }, + { url = 
"https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" }, + { url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" }, + { url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" }, + { url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" }, + { url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" }, + { url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" }, ] [[package]] diff --git a/dev/pyrefly-check-local b/dev/pyrefly-check-local new file mode 100755 index 0000000000..80f90927bb --- /dev/null +++ b/dev/pyrefly-check-local @@ -0,0 +1,34 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +REPO_ROOT="$SCRIPT_DIR/.." +cd "$REPO_ROOT" + +EXCLUDES_FILE="api/pyrefly-local-excludes.txt" + +pyrefly_args=( + "--summary=none" + "--project-excludes=.venv" + "--project-excludes=migrations/" + "--project-excludes=tests/" +) + +if [[ -f "$EXCLUDES_FILE" ]]; then + while IFS= read -r exclude; do + [[ -z "$exclude" || "${exclude:0:1}" == "#" ]] && continue + pyrefly_args+=("--project-excludes=$exclude") + done < "$EXCLUDES_FILE" +fi + +tmp_output="$(mktemp)" +set +e +uv run --directory api --dev pyrefly check "${pyrefly_args[@]}" >"$tmp_output" 2>&1 +pyrefly_status=$? 
+set -e + +uv run --directory api python libs/pyrefly_diagnostics.py < "$tmp_output" +rm -f "$tmp_output" + +exit "$pyrefly_status" diff --git a/dev/start-worker b/dev/start-worker index 0450851b56..8baa36f1ed 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -21,6 +21,7 @@ show_help() { echo "" echo "Available queues:" echo " dataset - RAG indexing and document processing" + echo " dataset_summary - LLM-heavy summary index generation (isolated from indexing)" echo " workflow - Workflow triggers (community edition)" echo " workflow_professional - Professional tier workflows (cloud edition)" echo " workflow_team - Team tier workflows (cloud edition)" @@ -106,10 +107,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + QUEUES="dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + 
QUEUES="dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" diff --git a/docs/tlh/README.md b/docs/tlh/README.md index a25849c443..e2acd7734c 100644 --- a/docs/tlh/README.md +++ b/docs/tlh/README.md @@ -61,7 +61,7 @@

langgenius%2Fdify | Trendshift

-Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features: +Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features:

**1. Workflow**: diff --git a/scripts/stress-test/common/config_helper.py b/scripts/stress-test/common/config_helper.py index 75fcbffa6f..fb34b43e26 100644 --- a/scripts/stress-test/common/config_helper.py +++ b/scripts/stress-test/common/config_helper.py @@ -6,6 +6,13 @@ from typing import Any class ConfigHelper: + _LEGACY_SECTION_MAP = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + """Helper class for reading and writing configuration files.""" def __init__(self, base_dir: Path | None = None): @@ -50,14 +57,8 @@ class ConfigHelper: Dictionary containing config data, or None if file doesn't exist """ # Provide backward compatibility for old config names - if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: - section_map = { - "admin_config": "admin", - "token_config": "auth", - "app_config": "app", - "api_key_config": "api_key", - } - return self.get_state_section(section_map[filename]) + if filename in self._LEGACY_SECTION_MAP: + return self.get_state_section(self._LEGACY_SECTION_MAP[filename]) config_path = self.get_config_path(filename) @@ -85,14 +86,11 @@ class ConfigHelper: True if successful, False otherwise """ # Provide backward compatibility for old config names - if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: - section_map = { - "admin_config": "admin", - "token_config": "auth", - "app_config": "app", - "api_key_config": "api_key", - } - return self.update_state_section(section_map[filename], data) + if filename in self._LEGACY_SECTION_MAP: + return self.update_state_section( + self._LEGACY_SECTION_MAP[filename], + data, + ) self.ensure_config_dir() config_path = self.get_config_path(filename) diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index afbb58fee1..7c8a293446 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -54,17 +54,22 @@ "publish:npm": 
"./scripts/publish.sh" }, "dependencies": { - "axios": "^1.13.2" + "axios": "^1.13.6" }, "devDependencies": { - "@eslint/js": "^9.39.2", - "@types/node": "^25.0.3", - "@typescript-eslint/eslint-plugin": "^8.50.1", - "@typescript-eslint/parser": "^8.50.1", - "@vitest/coverage-v8": "4.0.16", - "eslint": "^9.39.2", + "@eslint/js": "^10.0.1", + "@types/node": "^25.4.0", + "@typescript-eslint/eslint-plugin": "^8.57.0", + "@typescript-eslint/parser": "^8.57.0", + "@vitest/coverage-v8": "4.0.18", + "eslint": "^10.0.3", "tsup": "^8.5.1", "typescript": "^5.9.3", - "vitest": "^4.0.16" + "vitest": "^4.0.18" + }, + "pnpm": { + "overrides": { + "rollup@>=4.0.0,<4.59.0": "4.59.0" + } } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index 1923a0f063..b0aee38cdf 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -4,41 +4,44 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +overrides: + rollup@>=4.0.0,<4.59.0: 4.59.0 + importers: .: dependencies: axios: - specifier: ^1.13.2 - version: 1.13.5 + specifier: ^1.13.6 + version: 1.13.6 devDependencies: '@eslint/js': - specifier: ^9.39.2 - version: 9.39.2 + specifier: ^10.0.1 + version: 10.0.1(eslint@10.0.3) '@types/node': - specifier: ^25.0.3 - version: 25.0.3 + specifier: ^25.4.0 + version: 25.4.0 '@typescript-eslint/eslint-plugin': - specifier: ^8.50.1 - version: 8.50.1(@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3))(eslint@9.39.2)(typescript@5.9.3) + specifier: ^8.57.0 + version: 8.57.0(@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3))(eslint@10.0.3)(typescript@5.9.3) '@typescript-eslint/parser': - specifier: ^8.50.1 - version: 8.50.1(eslint@9.39.2)(typescript@5.9.3) + specifier: ^8.57.0 + version: 8.57.0(eslint@10.0.3)(typescript@5.9.3) '@vitest/coverage-v8': - specifier: 4.0.16 - version: 4.0.16(vitest@4.0.16(@types/node@25.0.3)) + specifier: 4.0.18 + version: 4.0.18(vitest@4.0.18(@types/node@25.4.0)) 
eslint: - specifier: ^9.39.2 - version: 9.39.2 + specifier: ^10.0.3 + version: 10.0.3 tsup: specifier: ^8.5.1 - version: 8.5.1(postcss@8.5.6)(typescript@5.9.3) + version: 8.5.1(postcss@8.5.8)(typescript@5.9.3) typescript: specifier: ^5.9.3 version: 5.9.3 vitest: - specifier: ^4.0.16 - version: 4.0.16(@types/node@25.0.3) + specifier: ^4.0.18 + version: 4.0.18(@types/node@25.4.0) packages: @@ -50,177 +53,177 @@ packages: resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==} engines: {node: '>=6.9.0'} - '@babel/parser@7.28.5': - resolution: {integrity: sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==} + '@babel/parser@7.29.0': + resolution: {integrity: sha512-IyDgFV5GeDUVX4YdF/3CPULtVGSXXMLh1xVIgdCgxApktqnQV0r7/8Nqthg+8YLGaAtdyIlo2qIdZrbCv4+7ww==} engines: {node: '>=6.0.0'} hasBin: true - '@babel/types@7.28.5': - resolution: {integrity: sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==} + '@babel/types@7.29.0': + resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==} engines: {node: '>=6.9.0'} '@bcoe/v8-coverage@1.0.2': resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} engines: {node: '>=18'} - '@esbuild/aix-ppc64@0.27.2': - resolution: {integrity: sha512-GZMB+a0mOMZs4MpDbj8RJp4cw+w1WV5NYD6xzgvzUJ5Ek2jerwfO2eADyI6ExDSUED+1X8aMbegahsJi+8mgpw==} + '@esbuild/aix-ppc64@0.27.3': + resolution: {integrity: sha512-9fJMTNFTWZMh5qwrBItuziu834eOCUcEqymSH7pY+zoMVEZg3gcPuBNxH1EvfVYe9h0x/Ptw8KBzv7qxb7l8dg==} engines: {node: '>=18'} cpu: [ppc64] os: [aix] - '@esbuild/android-arm64@0.27.2': - resolution: {integrity: sha512-pvz8ZZ7ot/RBphf8fv60ljmaoydPU12VuXHImtAs0XhLLw+EXBi2BLe3OYSBslR4rryHvweW5gmkKFwTiFy6KA==} + '@esbuild/android-arm64@0.27.3': + resolution: {integrity: 
sha512-YdghPYUmj/FX2SYKJ0OZxf+iaKgMsKHVPF1MAq/P8WirnSpCStzKJFjOjzsW0QQ7oIAiccHdcqjbHmJxRb/dmg==} engines: {node: '>=18'} cpu: [arm64] os: [android] - '@esbuild/android-arm@0.27.2': - resolution: {integrity: sha512-DVNI8jlPa7Ujbr1yjU2PfUSRtAUZPG9I1RwW4F4xFB1Imiu2on0ADiI/c3td+KmDtVKNbi+nffGDQMfcIMkwIA==} + '@esbuild/android-arm@0.27.3': + resolution: {integrity: sha512-i5D1hPY7GIQmXlXhs2w8AWHhenb00+GxjxRncS2ZM7YNVGNfaMxgzSGuO8o8SJzRc/oZwU2bcScvVERk03QhzA==} engines: {node: '>=18'} cpu: [arm] os: [android] - '@esbuild/android-x64@0.27.2': - resolution: {integrity: sha512-z8Ank4Byh4TJJOh4wpz8g2vDy75zFL0TlZlkUkEwYXuPSgX8yzep596n6mT7905kA9uHZsf/o2OJZubl2l3M7A==} + '@esbuild/android-x64@0.27.3': + resolution: {integrity: sha512-IN/0BNTkHtk8lkOM8JWAYFg4ORxBkZQf9zXiEOfERX/CzxW3Vg1ewAhU7QSWQpVIzTW+b8Xy+lGzdYXV6UZObQ==} engines: {node: '>=18'} cpu: [x64] os: [android] - '@esbuild/darwin-arm64@0.27.2': - resolution: {integrity: sha512-davCD2Zc80nzDVRwXTcQP/28fiJbcOwvdolL0sOiOsbwBa72kegmVU0Wrh1MYrbuCL98Omp5dVhQFWRKR2ZAlg==} + '@esbuild/darwin-arm64@0.27.3': + resolution: {integrity: sha512-Re491k7ByTVRy0t3EKWajdLIr0gz2kKKfzafkth4Q8A5n1xTHrkqZgLLjFEHVD+AXdUGgQMq+Godfq45mGpCKg==} engines: {node: '>=18'} cpu: [arm64] os: [darwin] - '@esbuild/darwin-x64@0.27.2': - resolution: {integrity: sha512-ZxtijOmlQCBWGwbVmwOF/UCzuGIbUkqB1faQRf5akQmxRJ1ujusWsb3CVfk/9iZKr2L5SMU5wPBi1UWbvL+VQA==} + '@esbuild/darwin-x64@0.27.3': + resolution: {integrity: sha512-vHk/hA7/1AckjGzRqi6wbo+jaShzRowYip6rt6q7VYEDX4LEy1pZfDpdxCBnGtl+A5zq8iXDcyuxwtv3hNtHFg==} engines: {node: '>=18'} cpu: [x64] os: [darwin] - '@esbuild/freebsd-arm64@0.27.2': - resolution: {integrity: sha512-lS/9CN+rgqQ9czogxlMcBMGd+l8Q3Nj1MFQwBZJyoEKI50XGxwuzznYdwcav6lpOGv5BqaZXqvBSiB/kJ5op+g==} + '@esbuild/freebsd-arm64@0.27.3': + resolution: {integrity: sha512-ipTYM2fjt3kQAYOvo6vcxJx3nBYAzPjgTCk7QEgZG8AUO3ydUhvelmhrbOheMnGOlaSFUoHXB6un+A7q4ygY9w==} engines: {node: '>=18'} cpu: [arm64] os: [freebsd] - '@esbuild/freebsd-x64@0.27.2': - 
resolution: {integrity: sha512-tAfqtNYb4YgPnJlEFu4c212HYjQWSO/w/h/lQaBK7RbwGIkBOuNKQI9tqWzx7Wtp7bTPaGC6MJvWI608P3wXYA==} + '@esbuild/freebsd-x64@0.27.3': + resolution: {integrity: sha512-dDk0X87T7mI6U3K9VjWtHOXqwAMJBNN2r7bejDsc+j03SEjtD9HrOl8gVFByeM0aJksoUuUVU9TBaZa2rgj0oA==} engines: {node: '>=18'} cpu: [x64] os: [freebsd] - '@esbuild/linux-arm64@0.27.2': - resolution: {integrity: sha512-hYxN8pr66NsCCiRFkHUAsxylNOcAQaxSSkHMMjcpx0si13t1LHFphxJZUiGwojB1a/Hd5OiPIqDdXONia6bhTw==} + '@esbuild/linux-arm64@0.27.3': + resolution: {integrity: sha512-sZOuFz/xWnZ4KH3YfFrKCf1WyPZHakVzTiqji3WDc0BCl2kBwiJLCXpzLzUBLgmp4veFZdvN5ChW4Eq/8Fc2Fg==} engines: {node: '>=18'} cpu: [arm64] os: [linux] - '@esbuild/linux-arm@0.27.2': - resolution: {integrity: sha512-vWfq4GaIMP9AIe4yj1ZUW18RDhx6EPQKjwe7n8BbIecFtCQG4CfHGaHuh7fdfq+y3LIA2vGS/o9ZBGVxIDi9hw==} + '@esbuild/linux-arm@0.27.3': + resolution: {integrity: sha512-s6nPv2QkSupJwLYyfS+gwdirm0ukyTFNl3KTgZEAiJDd+iHZcbTPPcWCcRYH+WlNbwChgH2QkE9NSlNrMT8Gfw==} engines: {node: '>=18'} cpu: [arm] os: [linux] - '@esbuild/linux-ia32@0.27.2': - resolution: {integrity: sha512-MJt5BRRSScPDwG2hLelYhAAKh9imjHK5+NE/tvnRLbIqUWa+0E9N4WNMjmp/kXXPHZGqPLxggwVhz7QP8CTR8w==} + '@esbuild/linux-ia32@0.27.3': + resolution: {integrity: sha512-yGlQYjdxtLdh0a3jHjuwOrxQjOZYD/C9PfdbgJJF3TIZWnm/tMd/RcNiLngiu4iwcBAOezdnSLAwQDPqTmtTYg==} engines: {node: '>=18'} cpu: [ia32] os: [linux] - '@esbuild/linux-loong64@0.27.2': - resolution: {integrity: sha512-lugyF1atnAT463aO6KPshVCJK5NgRnU4yb3FUumyVz+cGvZbontBgzeGFO1nF+dPueHD367a2ZXe1NtUkAjOtg==} + '@esbuild/linux-loong64@0.27.3': + resolution: {integrity: sha512-WO60Sn8ly3gtzhyjATDgieJNet/KqsDlX5nRC5Y3oTFcS1l0KWba+SEa9Ja1GfDqSF1z6hif/SkpQJbL63cgOA==} engines: {node: '>=18'} cpu: [loong64] os: [linux] - '@esbuild/linux-mips64el@0.27.2': - resolution: {integrity: sha512-nlP2I6ArEBewvJ2gjrrkESEZkB5mIoaTswuqNFRv/WYd+ATtUpe9Y09RnJvgvdag7he0OWgEZWhviS1OTOKixw==} + '@esbuild/linux-mips64el@0.27.3': + resolution: {integrity: 
sha512-APsymYA6sGcZ4pD6k+UxbDjOFSvPWyZhjaiPyl/f79xKxwTnrn5QUnXR5prvetuaSMsb4jgeHewIDCIWljrSxw==} engines: {node: '>=18'} cpu: [mips64el] os: [linux] - '@esbuild/linux-ppc64@0.27.2': - resolution: {integrity: sha512-C92gnpey7tUQONqg1n6dKVbx3vphKtTHJaNG2Ok9lGwbZil6DrfyecMsp9CrmXGQJmZ7iiVXvvZH6Ml5hL6XdQ==} + '@esbuild/linux-ppc64@0.27.3': + resolution: {integrity: sha512-eizBnTeBefojtDb9nSh4vvVQ3V9Qf9Df01PfawPcRzJH4gFSgrObw+LveUyDoKU3kxi5+9RJTCWlj4FjYXVPEA==} engines: {node: '>=18'} cpu: [ppc64] os: [linux] - '@esbuild/linux-riscv64@0.27.2': - resolution: {integrity: sha512-B5BOmojNtUyN8AXlK0QJyvjEZkWwy/FKvakkTDCziX95AowLZKR6aCDhG7LeF7uMCXEJqwa8Bejz5LTPYm8AvA==} + '@esbuild/linux-riscv64@0.27.3': + resolution: {integrity: sha512-3Emwh0r5wmfm3ssTWRQSyVhbOHvqegUDRd0WhmXKX2mkHJe1SFCMJhagUleMq+Uci34wLSipf8Lagt4LlpRFWQ==} engines: {node: '>=18'} cpu: [riscv64] os: [linux] - '@esbuild/linux-s390x@0.27.2': - resolution: {integrity: sha512-p4bm9+wsPwup5Z8f4EpfN63qNagQ47Ua2znaqGH6bqLlmJ4bx97Y9JdqxgGZ6Y8xVTixUnEkoKSHcpRlDnNr5w==} + '@esbuild/linux-s390x@0.27.3': + resolution: {integrity: sha512-pBHUx9LzXWBc7MFIEEL0yD/ZVtNgLytvx60gES28GcWMqil8ElCYR4kvbV2BDqsHOvVDRrOxGySBM9Fcv744hw==} engines: {node: '>=18'} cpu: [s390x] os: [linux] - '@esbuild/linux-x64@0.27.2': - resolution: {integrity: sha512-uwp2Tip5aPmH+NRUwTcfLb+W32WXjpFejTIOWZFw/v7/KnpCDKG66u4DLcurQpiYTiYwQ9B7KOeMJvLCu/OvbA==} + '@esbuild/linux-x64@0.27.3': + resolution: {integrity: sha512-Czi8yzXUWIQYAtL/2y6vogER8pvcsOsk5cpwL4Gk5nJqH5UZiVByIY8Eorm5R13gq+DQKYg0+JyQoytLQas4dA==} engines: {node: '>=18'} cpu: [x64] os: [linux] - '@esbuild/netbsd-arm64@0.27.2': - resolution: {integrity: sha512-Kj6DiBlwXrPsCRDeRvGAUb/LNrBASrfqAIok+xB0LxK8CHqxZ037viF13ugfsIpePH93mX7xfJp97cyDuTZ3cw==} + '@esbuild/netbsd-arm64@0.27.3': + resolution: {integrity: sha512-sDpk0RgmTCR/5HguIZa9n9u+HVKf40fbEUt+iTzSnCaGvY9kFP0YKBWZtJaraonFnqef5SlJ8/TiPAxzyS+UoA==} engines: {node: '>=18'} cpu: [arm64] os: [netbsd] - '@esbuild/netbsd-x64@0.27.2': - 
resolution: {integrity: sha512-HwGDZ0VLVBY3Y+Nw0JexZy9o/nUAWq9MlV7cahpaXKW6TOzfVno3y3/M8Ga8u8Yr7GldLOov27xiCnqRZf0tCA==} + '@esbuild/netbsd-x64@0.27.3': + resolution: {integrity: sha512-P14lFKJl/DdaE00LItAukUdZO5iqNH7+PjoBm+fLQjtxfcfFE20Xf5CrLsmZdq5LFFZzb5JMZ9grUwvtVYzjiA==} engines: {node: '>=18'} cpu: [x64] os: [netbsd] - '@esbuild/openbsd-arm64@0.27.2': - resolution: {integrity: sha512-DNIHH2BPQ5551A7oSHD0CKbwIA/Ox7+78/AWkbS5QoRzaqlev2uFayfSxq68EkonB+IKjiuxBFoV8ESJy8bOHA==} + '@esbuild/openbsd-arm64@0.27.3': + resolution: {integrity: sha512-AIcMP77AvirGbRl/UZFTq5hjXK+2wC7qFRGoHSDrZ5v5b8DK/GYpXW3CPRL53NkvDqb9D+alBiC/dV0Fb7eJcw==} engines: {node: '>=18'} cpu: [arm64] os: [openbsd] - '@esbuild/openbsd-x64@0.27.2': - resolution: {integrity: sha512-/it7w9Nb7+0KFIzjalNJVR5bOzA9Vay+yIPLVHfIQYG/j+j9VTH84aNB8ExGKPU4AzfaEvN9/V4HV+F+vo8OEg==} + '@esbuild/openbsd-x64@0.27.3': + resolution: {integrity: sha512-DnW2sRrBzA+YnE70LKqnM3P+z8vehfJWHXECbwBmH/CU51z6FiqTQTHFenPlHmo3a8UgpLyH3PT+87OViOh1AQ==} engines: {node: '>=18'} cpu: [x64] os: [openbsd] - '@esbuild/openharmony-arm64@0.27.2': - resolution: {integrity: sha512-LRBbCmiU51IXfeXk59csuX/aSaToeG7w48nMwA6049Y4J4+VbWALAuXcs+qcD04rHDuSCSRKdmY63sruDS5qag==} + '@esbuild/openharmony-arm64@0.27.3': + resolution: {integrity: sha512-NinAEgr/etERPTsZJ7aEZQvvg/A6IsZG/LgZy+81wON2huV7SrK3e63dU0XhyZP4RKGyTm7aOgmQk0bGp0fy2g==} engines: {node: '>=18'} cpu: [arm64] os: [openharmony] - '@esbuild/sunos-x64@0.27.2': - resolution: {integrity: sha512-kMtx1yqJHTmqaqHPAzKCAkDaKsffmXkPHThSfRwZGyuqyIeBvf08KSsYXl+abf5HDAPMJIPnbBfXvP2ZC2TfHg==} + '@esbuild/sunos-x64@0.27.3': + resolution: {integrity: sha512-PanZ+nEz+eWoBJ8/f8HKxTTD172SKwdXebZ0ndd953gt1HRBbhMsaNqjTyYLGLPdoWHy4zLU7bDVJztF5f3BHA==} engines: {node: '>=18'} cpu: [x64] os: [sunos] - '@esbuild/win32-arm64@0.27.2': - resolution: {integrity: sha512-Yaf78O/B3Kkh+nKABUF++bvJv5Ijoy9AN1ww904rOXZFLWVc5OLOfL56W+C8F9xn5JQZa3UX6m+IktJnIb1Jjg==} + '@esbuild/win32-arm64@0.27.3': + resolution: 
{integrity: sha512-B2t59lWWYrbRDw/tjiWOuzSsFh1Y/E95ofKz7rIVYSQkUYBjfSgf6oeYPNWHToFRr2zx52JKApIcAS/D5TUBnA==} engines: {node: '>=18'} cpu: [arm64] os: [win32] - '@esbuild/win32-ia32@0.27.2': - resolution: {integrity: sha512-Iuws0kxo4yusk7sw70Xa2E2imZU5HoixzxfGCdxwBdhiDgt9vX9VUCBhqcwY7/uh//78A1hMkkROMJq9l27oLQ==} + '@esbuild/win32-ia32@0.27.3': + resolution: {integrity: sha512-QLKSFeXNS8+tHW7tZpMtjlNb7HKau0QDpwm49u0vUp9y1WOF+PEzkU84y9GqYaAVW8aH8f3GcBck26jh54cX4Q==} engines: {node: '>=18'} cpu: [ia32] os: [win32] - '@esbuild/win32-x64@0.27.2': - resolution: {integrity: sha512-sRdU18mcKf7F+YgheI/zGf5alZatMUTKj/jNS6l744f9u3WFu4v7twcUI9vu4mknF4Y9aDlblIie0IM+5xxaqQ==} + '@esbuild/win32-x64@0.27.3': + resolution: {integrity: sha512-4uJGhsxuptu3OcpVAzli+/gWusVGwZZHTlS63hh++ehExkVT8SgiEf7/uC/PclrPPkLhZqGgCTjd0VWLo6xMqA==} engines: {node: '>=18'} cpu: [x64] os: [win32] - '@eslint-community/eslint-utils@4.9.0': - resolution: {integrity: sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==} + '@eslint-community/eslint-utils@4.9.1': + resolution: {integrity: sha512-phrYmNiYppR7znFEdqgfWHXR6NCkZEK7hwWDHZUjit/2/U0r6XvkDl0SYnoM51Hq7FhCGdLDT6zxCCOY1hexsQ==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} peerDependencies: eslint: ^6.0.0 || ^7.0.0 || >=8.0.0 @@ -229,33 +232,34 @@ packages: resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} - '@eslint/config-array@0.21.1': - resolution: {integrity: sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/config-array@0.23.3': + resolution: {integrity: sha512-j+eEWmB6YYLwcNOdlwQ6L2OsptI/LO6lNBuLIqe5R7RetD658HLoF+Mn7LzYmAWWNNzdC6cqP+L6r8ujeYXWLw==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/config-helpers@0.4.2': - resolution: {integrity: 
sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/config-helpers@0.5.3': + resolution: {integrity: sha512-lzGN0onllOZCGroKJmRwY6QcEHxbjBw1gwB8SgRSqK8YbbtEXMvKynsXc3553ckIEBxsbMBU7oOZXKIPGZNeZw==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/core@0.17.0': - resolution: {integrity: sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/core@1.1.1': + resolution: {integrity: sha512-QUPblTtE51/7/Zhfv8BDwO0qkkzQL7P/aWWbqcf4xWLEYn1oKjdO0gglQBB4GAsu7u6wjijbCmzsUTy6mnk6oQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/eslintrc@3.3.3': - resolution: {integrity: sha512-Kr+LPIUVKz2qkx1HAMH8q1q6azbqBAsXJUxBl/ODDuVPX45Z9DfwB8tPjTi6nNZ8BuM3nbJxC5zCAg5elnBUTQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/js@10.0.1': + resolution: {integrity: sha512-zeR9k5pd4gxjZ0abRoIaxdc7I3nDktoXZk2qOv9gCNWx3mVwEn32VRhyLaRsDiJjTs0xq/T8mfPtyuXu7GWBcA==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} + peerDependencies: + eslint: ^10.0.0 + peerDependenciesMeta: + eslint: + optional: true - '@eslint/js@9.39.2': - resolution: {integrity: sha512-q1mjIoW1VX4IvSocvM/vbTiveKC4k9eLrajNEuSsmjymSDEbpGddtpfOoN7YGAqBK3NG+uqo8ia4PDTt8buCYA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/object-schema@3.0.3': + resolution: {integrity: sha512-iM869Pugn9Nsxbh/YHRqYiqd23AmIbxJOcpUMOuWCVNdoQJ5ZtwL6h3t0bcZzJUlC3Dq9jCFCESBZnX0GTv7iQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/object-schema@2.1.7': - resolution: {integrity: sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - - '@eslint/plugin-kit@0.4.1': - resolution: {integrity: sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA==} - 
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/plugin-kit@0.6.1': + resolution: {integrity: sha512-iH1B076HoAshH1mLpHMgwdGeTs0CYwL0SPMkGuSebZrwBp16v415e9NZXg2jtrqPVQjf6IANe2Vtlr5KswtcZQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} '@humanfs/core@0.19.1': resolution: {integrity: sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==} @@ -286,113 +290,141 @@ packages: '@jridgewell/trace-mapping@0.3.31': resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} - '@rollup/rollup-android-arm-eabi@4.54.0': - resolution: {integrity: sha512-OywsdRHrFvCdvsewAInDKCNyR3laPA2mc9bRYJ6LBp5IyvF3fvXbbNR0bSzHlZVFtn6E0xw2oZlyjg4rKCVcng==} + '@rollup/rollup-android-arm-eabi@4.59.0': + resolution: {integrity: sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==} cpu: [arm] os: [android] - '@rollup/rollup-android-arm64@4.54.0': - resolution: {integrity: sha512-Skx39Uv+u7H224Af+bDgNinitlmHyQX1K/atIA32JP3JQw6hVODX5tkbi2zof/E69M1qH2UoN3Xdxgs90mmNYw==} + '@rollup/rollup-android-arm64@4.59.0': + resolution: {integrity: sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==} cpu: [arm64] os: [android] - '@rollup/rollup-darwin-arm64@4.54.0': - resolution: {integrity: sha512-k43D4qta/+6Fq+nCDhhv9yP2HdeKeP56QrUUTW7E6PhZP1US6NDqpJj4MY0jBHlJivVJD5P8NxrjuobZBJTCRw==} + '@rollup/rollup-darwin-arm64@4.59.0': + resolution: {integrity: sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==} cpu: [arm64] os: [darwin] - '@rollup/rollup-darwin-x64@4.54.0': - resolution: {integrity: sha512-cOo7biqwkpawslEfox5Vs8/qj83M/aZCSSNIWpVzfU2CYHa2G3P1UN5WF01RdTHSgCkri7XOlTdtk17BezlV3A==} + '@rollup/rollup-darwin-x64@4.59.0': + resolution: {integrity: sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==} cpu: [x64] os: [darwin] - 
'@rollup/rollup-freebsd-arm64@4.54.0': - resolution: {integrity: sha512-miSvuFkmvFbgJ1BevMa4CPCFt5MPGw094knM64W9I0giUIMMmRYcGW/JWZDriaw/k1kOBtsWh1z6nIFV1vPNtA==} + '@rollup/rollup-freebsd-arm64@4.59.0': + resolution: {integrity: sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==} cpu: [arm64] os: [freebsd] - '@rollup/rollup-freebsd-x64@4.54.0': - resolution: {integrity: sha512-KGXIs55+b/ZfZsq9aR026tmr/+7tq6VG6MsnrvF4H8VhwflTIuYh+LFUlIsRdQSgrgmtM3fVATzEAj4hBQlaqQ==} + '@rollup/rollup-freebsd-x64@4.59.0': + resolution: {integrity: sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==} cpu: [x64] os: [freebsd] - '@rollup/rollup-linux-arm-gnueabihf@4.54.0': - resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==} + '@rollup/rollup-linux-arm-gnueabihf@4.59.0': + resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-arm-musleabihf@4.54.0': - resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==} + '@rollup/rollup-linux-arm-musleabihf@4.59.0': + resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] + libc: [musl] - '@rollup/rollup-linux-arm64-gnu@4.54.0': - resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==} + '@rollup/rollup-linux-arm64-gnu@4.59.0': + resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-arm64-musl@4.54.0': - resolution: {integrity: 
sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==} + '@rollup/rollup-linux-arm64-musl@4.59.0': + resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] + libc: [musl] - '@rollup/rollup-linux-loong64-gnu@4.54.0': - resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==} + '@rollup/rollup-linux-loong64-gnu@4.59.0': + resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-ppc64-gnu@4.54.0': - resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==} + '@rollup/rollup-linux-loong64-musl@4.59.0': + resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} + cpu: [loong64] + os: [linux] + libc: [musl] + + '@rollup/rollup-linux-ppc64-gnu@4.59.0': + resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-riscv64-gnu@4.54.0': - resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==} + '@rollup/rollup-linux-ppc64-musl@4.59.0': + resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} + cpu: [ppc64] + os: [linux] + libc: [musl] + + '@rollup/rollup-linux-riscv64-gnu@4.59.0': + resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-riscv64-musl@4.54.0': - resolution: {integrity: 
sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==} + '@rollup/rollup-linux-riscv64-musl@4.59.0': + resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] + libc: [musl] - '@rollup/rollup-linux-s390x-gnu@4.54.0': - resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==} + '@rollup/rollup-linux-s390x-gnu@4.59.0': + resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-x64-gnu@4.54.0': - resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==} + '@rollup/rollup-linux-x64-gnu@4.59.0': + resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} cpu: [x64] os: [linux] + libc: [glibc] - '@rollup/rollup-linux-x64-musl@4.54.0': - resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==} + '@rollup/rollup-linux-x64-musl@4.59.0': + resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} cpu: [x64] os: [linux] + libc: [musl] - '@rollup/rollup-openharmony-arm64@4.54.0': - resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==} + '@rollup/rollup-openbsd-x64@4.59.0': + resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} + cpu: [x64] + os: [openbsd] + + '@rollup/rollup-openharmony-arm64@4.59.0': + resolution: {integrity: sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==} cpu: [arm64] os: [openharmony] - 
'@rollup/rollup-win32-arm64-msvc@4.54.0': - resolution: {integrity: sha512-c2V0W1bsKIKfbLMBu/WGBz6Yci8nJ/ZJdheE0EwB73N3MvHYKiKGs3mVilX4Gs70eGeDaMqEob25Tw2Gb9Nqyw==} + '@rollup/rollup-win32-arm64-msvc@4.59.0': + resolution: {integrity: sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==} cpu: [arm64] os: [win32] - '@rollup/rollup-win32-ia32-msvc@4.54.0': - resolution: {integrity: sha512-woEHgqQqDCkAzrDhvDipnSirm5vxUXtSKDYTVpZG3nUdW/VVB5VdCYA2iReSj/u3yCZzXID4kuKG7OynPnB3WQ==} + '@rollup/rollup-win32-ia32-msvc@4.59.0': + resolution: {integrity: sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==} cpu: [ia32] os: [win32] - '@rollup/rollup-win32-x64-gnu@4.54.0': - resolution: {integrity: sha512-dzAc53LOuFvHwbCEOS0rPbXp6SIhAf2txMP5p6mGyOXXw5mWY8NGGbPMPrs4P1WItkfApDathBj/NzMLUZ9rtQ==} + '@rollup/rollup-win32-x64-gnu@4.59.0': + resolution: {integrity: sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==} cpu: [x64] os: [win32] - '@rollup/rollup-win32-x64-msvc@4.54.0': - resolution: {integrity: sha512-hYT5d3YNdSh3mbCU1gwQyPgQd3T2ne0A3KG8KSBdav5TiBg6eInVmV+TeR5uHufiIgSFg0XsOWGW5/RhNcSvPg==} + '@rollup/rollup-win32-x64-msvc@4.59.0': + resolution: {integrity: sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==} cpu: [x64] os: [win32] @@ -405,88 +437,91 @@ packages: '@types/deep-eql@4.0.2': resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} + '@types/esrecurse@4.3.1': + resolution: {integrity: sha512-xJBAbDifo5hpffDBuHl0Y8ywswbiAp/Wi7Y/GtAgSlZyIABppyurxVueOPE8LUQOxdlgi6Zqce7uoEpqNTeiUw==} + '@types/estree@1.0.8': resolution: {integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} '@types/json-schema@7.0.15': resolution: {integrity: 
sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==} - '@types/node@25.0.3': - resolution: {integrity: sha512-W609buLVRVmeW693xKfzHeIV6nJGGz98uCPfeXI1ELMLXVeKYZ9m15fAMSaUPBHYLGFsVRcMmSCksQOrZV9BYA==} + '@types/node@25.4.0': + resolution: {integrity: sha512-9wLpoeWuBlcbBpOY3XmzSTG3oscB6xjBEEtn+pYXTfhyXhIxC5FsBer2KTopBlvKEiW9l13po9fq+SJY/5lkhw==} - '@typescript-eslint/eslint-plugin@8.50.1': - resolution: {integrity: sha512-PKhLGDq3JAg0Jk/aK890knnqduuI/Qj+udH7wCf0217IGi4gt+acgCyPVe79qoT+qKUvHMDQkwJeKW9fwl8Cyw==} + '@typescript-eslint/eslint-plugin@8.57.0': + resolution: {integrity: sha512-qeu4rTHR3/IaFORbD16gmjq9+rEs9fGKdX0kF6BKSfi+gCuG3RCKLlSBYzn/bGsY9Tj7KE/DAQStbp8AHJGHEQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - '@typescript-eslint/parser': ^8.50.1 - eslint: ^8.57.0 || ^9.0.0 + '@typescript-eslint/parser': ^8.57.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/parser@8.50.1': - resolution: {integrity: sha512-hM5faZwg7aVNa819m/5r7D0h0c9yC4DUlWAOvHAtISdFTc8xB86VmX5Xqabrama3wIPJ/q9RbGS1worb6JfnMg==} + '@typescript-eslint/parser@8.57.0': + resolution: {integrity: sha512-XZzOmihLIr8AD1b9hL9ccNMzEMWt/dE2u7NyTY9jJG6YNiNthaD5XtUHVF2uCXZ15ng+z2hT3MVuxnUYhq6k1g==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/project-service@8.50.1': - resolution: {integrity: sha512-E1ur1MCVf+YiP89+o4Les/oBAVzmSbeRB0MQLfSlYtbWU17HPxZ6Bhs5iYmKZRALvEuBoXIZMOIRRc/P++Ortg==} + '@typescript-eslint/project-service@8.57.0': + resolution: {integrity: sha512-pR+dK0BlxCLxtWfaKQWtYr7MhKmzqZxuii+ZjuFlZlIGRZm22HnXFqa2eY+90MUz8/i80YJmzFGDUsi8dMOV5w==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/scope-manager@8.50.1': - resolution: {integrity: 
sha512-mfRx06Myt3T4vuoHaKi8ZWNTPdzKPNBhiblze5N50//TSHOAQQevl/aolqA/BcqqbJ88GUnLqjjcBc8EWdBcVw==} + '@typescript-eslint/scope-manager@8.57.0': + resolution: {integrity: sha512-nvExQqAHF01lUM66MskSaZulpPL5pgy5hI5RfrxviLgzZVffB5yYzw27uK/ft8QnKXI2X0LBrHJFr1TaZtAibw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/tsconfig-utils@8.50.1': - resolution: {integrity: sha512-ooHmotT/lCWLXi55G4mvaUF60aJa012QzvLK0Y+Mp4WdSt17QhMhWOaBWeGTFVkb2gDgBe19Cxy1elPXylslDw==} + '@typescript-eslint/tsconfig-utils@8.57.0': + resolution: {integrity: sha512-LtXRihc5ytjJIQEH+xqjB0+YgsV4/tW35XKX3GTZHpWtcC8SPkT/d4tqdf1cKtesryHm2bgp6l555NYcT2NLvA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/type-utils@8.50.1': - resolution: {integrity: sha512-7J3bf022QZE42tYMO6SL+6lTPKFk/WphhRPe9Tw/el+cEwzLz1Jjz2PX3GtGQVxooLDKeMVmMt7fWpYRdG5Etg==} + '@typescript-eslint/type-utils@8.57.0': + resolution: {integrity: sha512-yjgh7gmDcJ1+TcEg8x3uWQmn8ifvSupnPfjP21twPKrDP/pTHlEQgmKcitzF/rzPSmv7QjJ90vRpN4U+zoUjwQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/types@8.50.1': - resolution: {integrity: sha512-v5lFIS2feTkNyMhd7AucE/9j/4V9v5iIbpVRncjk/K0sQ6Sb+Np9fgYS/63n6nwqahHQvbmujeBL7mp07Q9mlA==} + '@typescript-eslint/types@8.57.0': + resolution: {integrity: sha512-dTLI8PEXhjUC7B9Kre+u0XznO696BhXcTlOn0/6kf1fHaQW8+VjJAVHJ3eTI14ZapTxdkOmc80HblPQLaEeJdg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/typescript-estree@8.50.1': - resolution: {integrity: sha512-woHPdW+0gj53aM+cxchymJCrh0cyS7BTIdcDxWUNsclr9VDkOSbqC13juHzxOmQ22dDkMZEpZB+3X1WpUvzgVQ==} + '@typescript-eslint/typescript-estree@8.57.0': + resolution: {integrity: sha512-m7faHcyVg0BT3VdYTlX8GdJEM7COexXxS6KqGopxdtkQRvBanK377QDHr4W/vIPAR+ah9+B/RclSW5ldVniO1Q==} engines: {node: ^18.18.0 || ^20.9.0 || 
>=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/utils@8.50.1': - resolution: {integrity: sha512-lCLp8H1T9T7gPbEuJSnHwnSuO9mDf8mfK/Nion5mZmiEaQD9sWf9W4dfeFqRyqRjF06/kBuTmAqcs9sewM2NbQ==} + '@typescript-eslint/utils@8.57.0': + resolution: {integrity: sha512-5iIHvpD3CZe06riAsbNxxreP+MuYgVUsV0n4bwLH//VJmgtt54sQeY2GszntJ4BjYCpMzrfVh2SBnUQTtys2lQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/visitor-keys@8.50.1': - resolution: {integrity: sha512-IrDKrw7pCRUR94zeuCSUWQ+w8JEf5ZX5jl/e6AHGSLi1/zIr0lgutfn/7JpfCey+urpgQEdrZVYzCaVVKiTwhQ==} + '@typescript-eslint/visitor-keys@8.57.0': + resolution: {integrity: sha512-zm6xx8UT/Xy2oSr2ZXD0pZo7Jx2XsCoID2IUh9YSTFRu7z+WdwYTRk6LhUftm1crwqbuoF6I8zAFeCMw0YjwDg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@vitest/coverage-v8@4.0.16': - resolution: {integrity: sha512-2rNdjEIsPRzsdu6/9Eq0AYAzYdpP6Bx9cje9tL3FE5XzXRQF1fNU9pe/1yE8fCrS0HD+fBtt6gLPh6LI57tX7A==} + '@vitest/coverage-v8@4.0.18': + resolution: {integrity: sha512-7i+N2i0+ME+2JFZhfuz7Tg/FqKtilHjGyGvoHYQ6iLV0zahbsJ9sljC9OcFcPDbhYKCet+sG8SsVqlyGvPflZg==} peerDependencies: - '@vitest/browser': 4.0.16 - vitest: 4.0.16 + '@vitest/browser': 4.0.18 + vitest: 4.0.18 peerDependenciesMeta: '@vitest/browser': optional: true - '@vitest/expect@4.0.16': - resolution: {integrity: sha512-eshqULT2It7McaJkQGLkPjPjNph+uevROGuIMJdG3V+0BSR2w9u6J9Lwu+E8cK5TETlfou8GRijhafIMhXsimA==} + '@vitest/expect@4.0.18': + resolution: {integrity: sha512-8sCWUyckXXYvx4opfzVY03EOiYVxyNrHS5QxX3DAIi5dpJAAkyJezHCP77VMX4HKA2LDT/Jpfo8i2r5BE3GnQQ==} - '@vitest/mocker@4.0.16': - resolution: {integrity: sha512-yb6k4AZxJTB+q9ycAvsoxGn+j/po0UaPgajllBgt1PzoMAAmJGYFdDk0uCcRcxb3BrME34I6u8gHZTQlkqSZpg==} + '@vitest/mocker@4.0.18': + resolution: {integrity: sha512-HhVd0MDnzzsgevnOWCBj5Otnzobjy5wLBe4EdeeFGv8luMsGcYqDuFRMcttKWZA5vVO8RFjexVovXvAM4JoJDQ==} 
peerDependencies: msw: ^2.4.9 vite: ^6.0.0 || ^7.0.0-0 @@ -496,65 +531,57 @@ packages: vite: optional: true - '@vitest/pretty-format@4.0.16': - resolution: {integrity: sha512-eNCYNsSty9xJKi/UdVD8Ou16alu7AYiS2fCPRs0b1OdhJiV89buAXQLpTbe+X8V9L6qrs9CqyvU7OaAopJYPsA==} + '@vitest/pretty-format@4.0.18': + resolution: {integrity: sha512-P24GK3GulZWC5tz87ux0m8OADrQIUVDPIjjj65vBXYG17ZeU3qD7r+MNZ1RNv4l8CGU2vtTRqixrOi9fYk/yKw==} - '@vitest/runner@4.0.16': - resolution: {integrity: sha512-VWEDm5Wv9xEo80ctjORcTQRJ539EGPB3Pb9ApvVRAY1U/WkHXmmYISqU5E79uCwcW7xYUV38gwZD+RV755fu3Q==} + '@vitest/runner@4.0.18': + resolution: {integrity: sha512-rpk9y12PGa22Jg6g5M3UVVnTS7+zycIGk9ZNGN+m6tZHKQb7jrP7/77WfZy13Y/EUDd52NDsLRQhYKtv7XfPQw==} - '@vitest/snapshot@4.0.16': - resolution: {integrity: sha512-sf6NcrYhYBsSYefxnry+DR8n3UV4xWZwWxYbCJUt2YdvtqzSPR7VfGrY0zsv090DAbjFZsi7ZaMi1KnSRyK1XA==} + '@vitest/snapshot@4.0.18': + resolution: {integrity: sha512-PCiV0rcl7jKQjbgYqjtakly6T1uwv/5BQ9SwBLekVg/EaYeQFPiXcgrC2Y7vDMA8dM1SUEAEV82kgSQIlXNMvA==} - '@vitest/spy@4.0.16': - resolution: {integrity: sha512-4jIOWjKP0ZUaEmJm00E0cOBLU+5WE0BpeNr3XN6TEF05ltro6NJqHWxXD0kA8/Zc8Nh23AT8WQxwNG+WeROupw==} + '@vitest/spy@4.0.18': + resolution: {integrity: sha512-cbQt3PTSD7P2OARdVW3qWER5EGq7PHlvE+QfzSC0lbwO+xnt7+XH06ZzFjFRgzUX//JmpxrCu92VdwvEPlWSNw==} - '@vitest/utils@4.0.16': - resolution: {integrity: sha512-h8z9yYhV3e1LEfaQ3zdypIrnAg/9hguReGZoS7Gl0aBG5xgA410zBqECqmaF/+RkTggRsfnzc1XaAHA6bmUufA==} + '@vitest/utils@4.0.18': + resolution: {integrity: sha512-msMRKLMVLWygpK3u2Hybgi4MNjcYJvwTb0Ru09+fOyCXIgT5raYP041DRRdiJiI3k/2U6SEbAETB3YtBrUkCFA==} acorn-jsx@5.3.2: resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} peerDependencies: acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 - acorn@8.15.0: - resolution: {integrity: sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==} + acorn@8.16.0: + resolution: {integrity: 
sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==} engines: {node: '>=0.4.0'} hasBin: true - ajv@6.12.6: - resolution: {integrity: sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==} - - ansi-styles@4.3.0: - resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} - engines: {node: '>=8'} + ajv@6.14.0: + resolution: {integrity: sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==} any-promise@1.3.0: resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} - argparse@2.0.1: - resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} - assertion-error@2.0.1: resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} engines: {node: '>=12'} - ast-v8-to-istanbul@0.3.10: - resolution: {integrity: sha512-p4K7vMz2ZSk3wN8l5o3y2bJAoZXT3VuJI5OLTATY/01CYWumWvwkUw0SqDBnNq6IiTO3qDa1eSQDibAV8g7XOQ==} + ast-v8-to-istanbul@0.3.12: + resolution: {integrity: sha512-BRRC8VRZY2R4Z4lFIL35MwNXmwVqBityvOIwETtsCSwvjl0IdgFsy9NhdaA6j74nUdtJJlIypeRhpDam19Wq3g==} asynckit@0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} - axios@1.13.5: - resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==} + axios@1.13.6: + resolution: {integrity: sha512-ChTCHMouEe2kn713WHbQGcuYrr6fXTBiu460OTwWrWob16g1bXn4vtz07Ope7ewMozJAnEquLk5lWQWtBig9DQ==} - balanced-match@1.0.2: - resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} + balanced-match@4.0.4: + resolution: {integrity: 
sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==} + engines: {node: 18 || 20 || >=22} - brace-expansion@1.1.12: - resolution: {integrity: sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==} - - brace-expansion@2.0.2: - resolution: {integrity: sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==} + brace-expansion@5.0.4: + resolution: {integrity: sha512-h+DEnpVvxmfVefa4jFbCf5HdH5YMDXRsmKflpf1pILZWRFlTbJpxeU55nJl4Smt5HQaGzg1o6RHFPJaOqnmBDg==} + engines: {node: 18 || 20 || >=22} bundle-require@5.1.0: resolution: {integrity: sha512-3WrrOuZiyaaZPWiEt4G3+IffISVC9HYlWueJEBWED4ZH4aIAC2PnkdnuRrR94M+w6yGWn4AglWtJtBI8YqvgoA==} @@ -570,29 +597,14 @@ packages: resolution: {integrity: sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==} engines: {node: '>= 0.4'} - callsites@3.1.0: - resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==} - engines: {node: '>=6'} - chai@6.2.2: resolution: {integrity: sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==} engines: {node: '>=18'} - chalk@4.1.2: - resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} - engines: {node: '>=10'} - chokidar@4.0.3: resolution: {integrity: sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA==} engines: {node: '>= 14.16.0'} - color-convert@2.0.1: - resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} - engines: {node: '>=7.0.0'} - - color-name@1.1.4: - resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} - combined-stream@1.0.8: resolution: {integrity: 
sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} engines: {node: '>= 0.8'} @@ -601,9 +613,6 @@ packages: resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==} engines: {node: '>= 6'} - concat-map@0.0.1: - resolution: {integrity: sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==} - confbox@0.1.8: resolution: {integrity: sha512-RMtmw0iFkeR4YV+fUOSucriAQNb9g8zFR52MWCtl+cCZOFRNL6zeB395vPzFhEjjn4fMxXudmELnl/KF/WrK6w==} @@ -654,8 +663,8 @@ packages: resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} engines: {node: '>= 0.4'} - esbuild@0.27.2: - resolution: {integrity: sha512-HyNQImnsOC7X9PMNaCIeAm4ISCQXs5a5YasTXVliKv4uuBo1dKrG0A+uQS8M5eXjVMnLg3WgXaKvprHlFJQffw==} + esbuild@0.27.3: + resolution: {integrity: sha512-8VwMnyGCONIs6cWue2IdpHxHnAjzxnw2Zr7MkVxB2vjmQ2ivqGFb4LEG3SMnv0Gb2F/G/2yA8zUaiL1gywDCCg==} engines: {node: '>=18'} hasBin: true @@ -663,21 +672,21 @@ packages: resolution: {integrity: sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==} engines: {node: '>=10'} - eslint-scope@8.4.0: - resolution: {integrity: sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + eslint-scope@9.1.2: + resolution: {integrity: sha512-xS90H51cKw0jltxmvmHy2Iai1LIqrfbw57b79w/J7MfvDfkIkFZ+kj6zC3BjtUwh150HsSSdxXZcsuv72miDFQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} eslint-visitor-keys@3.4.3: resolution: {integrity: sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} - eslint-visitor-keys@4.2.1: - resolution: {integrity: sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==} - 
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + eslint-visitor-keys@5.0.1: + resolution: {integrity: sha512-tD40eHxA35h0PEIZNeIjkHoDR4YjjJp34biM0mDvplBe//mB+IHCqHDGV7pxF+7MklTvighcCPPZC7ynWyjdTA==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - eslint@9.39.2: - resolution: {integrity: sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + eslint@10.0.3: + resolution: {integrity: sha512-COV33RzXZkqhG9P2rZCFl9ZmJ7WL+gQSCRzE7RhkbclbQPtLAWReL7ysA0Sh4c8Im2U9ynybdR56PV0XcKvqaQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} hasBin: true peerDependencies: jiti: '*' @@ -685,12 +694,12 @@ packages: jiti: optional: true - espree@10.4.0: - resolution: {integrity: sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + espree@11.2.0: + resolution: {integrity: sha512-7p3DrVEIopW1B1avAGLuCSh1jubc01H2JHc8B4qqGblmg5gI9yumBgACjWo4JlIc04ufug4xJ3SQI8HkS/Rgzw==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - esquery@1.6.0: - resolution: {integrity: sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==} + esquery@1.7.0: + resolution: {integrity: sha512-Ap6G0WQwcU/LHsvLwON1fAQX9Zp0A2Y6Y/cJBl9r/JbW90Zyg4/zbG6zzKa2OTALELarYHmKu0GhpM5EO+7T0g==} engines: {node: '>=0.10'} esrecurse@4.3.0: @@ -745,8 +754,8 @@ packages: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} - flatted@3.3.3: - resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==} + flatted@3.4.1: + resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} follow-redirects@1.15.11: resolution: {integrity: 
sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} @@ -781,10 +790,6 @@ packages: resolution: {integrity: sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==} engines: {node: '>=10.13.0'} - globals@14.0.0: - resolution: {integrity: sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==} - engines: {node: '>=18'} - gopd@1.2.0: resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==} engines: {node: '>= 0.4'} @@ -816,10 +821,6 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} - import-fresh@3.3.1: - resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==} - engines: {node: '>=6'} - imurmurhash@0.1.4: resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==} engines: {node: '>=0.8.19'} @@ -843,10 +844,6 @@ packages: resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==} engines: {node: '>=10'} - istanbul-lib-source-maps@5.0.6: - resolution: {integrity: sha512-yg2d+Em4KizZC5niWhQaIomgf5WlL4vOOjZ5xGCmF8SnPE/mDWWXgvRExdcpCgh9lLRRa1/fSYp2ymmbJ1pI+A==} - engines: {node: '>=10'} - istanbul-reports@3.2.0: resolution: {integrity: sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==} engines: {node: '>=8'} @@ -855,12 +852,8 @@ packages: resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} engines: {node: '>=10'} - js-tokens@9.0.1: - resolution: {integrity: sha512-mxa9E9ITFOt0ban3j6L5MpjwegGz6lBQmM1IJkWeBZGcMxto50+eWdjC/52xDbS2vy0k7vIMK0Fe2wfL9OQSpQ==} - - js-yaml@4.1.1: - 
resolution: {integrity: sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==} - hasBin: true + js-tokens@10.0.0: + resolution: {integrity: sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==} json-buffer@3.0.1: resolution: {integrity: sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==} @@ -893,14 +886,11 @@ packages: resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==} engines: {node: '>=10'} - lodash.merge@4.6.2: - resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} - magic-string@0.30.21: resolution: {integrity: sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==} - magicast@0.5.1: - resolution: {integrity: sha512-xrHS24IxaLrvuo613F719wvOIv9xPHFWQHuvGUBmPnCA/3MQxKI3b+r7n1jAoDHmsbC5bRhTZYR77invLAxVnw==} + magicast@0.5.2: + resolution: {integrity: sha512-E3ZJh4J3S9KfwdjZhe2afj6R9lGIN5Pher1pF39UGrXRqq/VDaGVIGN13BjHd2u8B61hArAGOnso7nBOouW3TQ==} make-dir@4.0.0: resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==} @@ -918,15 +908,12 @@ packages: resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} engines: {node: '>= 0.6'} - minimatch@3.1.2: - resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==} + minimatch@10.2.4: + resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==} + engines: {node: 18 || 20 || >=22} - minimatch@9.0.5: - resolution: {integrity: sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==} - engines: {node: '>=16 || 14 >=14.17'} - - 
mlly@1.8.0: - resolution: {integrity: sha512-l8D9ODSRWLe2KHJSifWGwBqpTZXIXTeo8mlKjY+E2HAakaTeNpqAyBZ8GSqLzHgw4XmHmC8whvpjJNMbFZN7/g==} + mlly@1.8.1: + resolution: {integrity: sha512-SnL6sNutTwRWWR/vcmCYHSADjiEesp5TGQQ0pXyLhW5IoeibRlF/CbSLailbB3CNqJUk9cVJ9dUDnbD7GrcHBQ==} ms@2.1.3: resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} @@ -961,10 +948,6 @@ packages: resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} engines: {node: '>=10'} - parent-module@1.0.1: - resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==} - engines: {node: '>=6'} - path-exists@4.0.0: resolution: {integrity: sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==} engines: {node: '>=8'} @@ -1008,8 +991,8 @@ packages: yaml: optional: true - postcss@8.5.6: - resolution: {integrity: sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==} + postcss@8.5.8: + resolution: {integrity: sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==} engines: {node: ^10 || ^12 || >=14} prelude-ls@1.2.1: @@ -1027,21 +1010,17 @@ packages: resolution: {integrity: sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==} engines: {node: '>= 14.18.0'} - resolve-from@4.0.0: - resolution: {integrity: sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==} - engines: {node: '>=4'} - resolve-from@5.0.0: resolution: {integrity: sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==} engines: {node: '>=8'} - rollup@4.54.0: - resolution: {integrity: sha512-3nk8Y3a9Ea8szgKhinMlGMhGMw89mqule3KWczxhIzqudyHdCIOHw8WJlj/r329fACjKLEh13ZSk7oE22kyeIw==} + rollup@4.59.0: + resolution: 
{integrity: sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==} engines: {node: '>=18.0.0', npm: '>=8.0.0'} hasBin: true - semver@7.7.3: - resolution: {integrity: sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==} + semver@7.7.4: + resolution: {integrity: sha512-vFKC2IEtQnVhpT78h1Yp8wzwrf8CM+MzKMHGJZfBtzhZNycRFnXsHk6E5TxIkkMsgNS7mdX3AGB7x2QM2di4lA==} engines: {node: '>=10'} hasBin: true @@ -1070,10 +1049,6 @@ packages: std-env@3.10.0: resolution: {integrity: sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==} - strip-json-comments@3.1.1: - resolution: {integrity: sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==} - engines: {node: '>=8'} - sucrase@3.35.1: resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==} engines: {node: '>=16 || 14 >=14.17'} @@ -1112,8 +1087,8 @@ packages: resolution: {integrity: sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==} hasBin: true - ts-api-utils@2.1.0: - resolution: {integrity: sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==} + ts-api-utils@2.4.0: + resolution: {integrity: sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA==} engines: {node: '>=18.12'} peerDependencies: typescript: '>=4.8.4' @@ -1149,17 +1124,17 @@ packages: engines: {node: '>=14.17'} hasBin: true - ufo@1.6.1: - resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} + ufo@1.6.3: + resolution: {integrity: sha512-yDJTmhydvl5lJzBmy/hyOAA0d+aqCBuwl818haVdYCRrWV84o7YyeVm4QlVHStqNrrJSTb6jKuFAVqAFsr+K3Q==} - undici-types@7.16.0: - resolution: {integrity: 
sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==} + undici-types@7.18.2: + resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==} uri-js@4.4.1: resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} - vite@7.3.0: - resolution: {integrity: sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==} + vite@7.3.1: + resolution: {integrity: sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==} engines: {node: ^20.19.0 || >=22.12.0} hasBin: true peerDependencies: @@ -1198,18 +1173,18 @@ packages: yaml: optional: true - vitest@4.0.16: - resolution: {integrity: sha512-E4t7DJ9pESL6E3I8nFjPa4xGUd3PmiWDLsDztS2qXSJWfHtbQnwAWylaBvSNY48I3vr8PTqIZlyK8TE3V3CA4Q==} + vitest@4.0.18: + resolution: {integrity: sha512-hOQuK7h0FGKgBAas7v0mSAsnvrIgAvWmRFjmzpJ7SwFHH3g1k2u37JtYwOwmEKhK6ZO3v9ggDBBm0La1LCK4uQ==} engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} hasBin: true peerDependencies: '@edge-runtime/vm': '*' '@opentelemetry/api': ^1.9.0 '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 - '@vitest/browser-playwright': 4.0.16 - '@vitest/browser-preview': 4.0.16 - '@vitest/browser-webdriverio': 4.0.16 - '@vitest/ui': 4.0.16 + '@vitest/browser-playwright': 4.0.18 + '@vitest/browser-preview': 4.0.18 + '@vitest/browser-webdriverio': 4.0.18 + '@vitest/ui': 4.0.18 happy-dom: '*' jsdom: '*' peerDependenciesMeta: @@ -1256,139 +1231,127 @@ snapshots: '@babel/helper-validator-identifier@7.28.5': {} - '@babel/parser@7.28.5': + '@babel/parser@7.29.0': dependencies: - '@babel/types': 7.28.5 + '@babel/types': 7.29.0 - '@babel/types@7.28.5': + '@babel/types@7.29.0': dependencies: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.28.5 '@bcoe/v8-coverage@1.0.2': {} - '@esbuild/aix-ppc64@0.27.2': + '@esbuild/aix-ppc64@0.27.3': 
optional: true - '@esbuild/android-arm64@0.27.2': + '@esbuild/android-arm64@0.27.3': optional: true - '@esbuild/android-arm@0.27.2': + '@esbuild/android-arm@0.27.3': optional: true - '@esbuild/android-x64@0.27.2': + '@esbuild/android-x64@0.27.3': optional: true - '@esbuild/darwin-arm64@0.27.2': + '@esbuild/darwin-arm64@0.27.3': optional: true - '@esbuild/darwin-x64@0.27.2': + '@esbuild/darwin-x64@0.27.3': optional: true - '@esbuild/freebsd-arm64@0.27.2': + '@esbuild/freebsd-arm64@0.27.3': optional: true - '@esbuild/freebsd-x64@0.27.2': + '@esbuild/freebsd-x64@0.27.3': optional: true - '@esbuild/linux-arm64@0.27.2': + '@esbuild/linux-arm64@0.27.3': optional: true - '@esbuild/linux-arm@0.27.2': + '@esbuild/linux-arm@0.27.3': optional: true - '@esbuild/linux-ia32@0.27.2': + '@esbuild/linux-ia32@0.27.3': optional: true - '@esbuild/linux-loong64@0.27.2': + '@esbuild/linux-loong64@0.27.3': optional: true - '@esbuild/linux-mips64el@0.27.2': + '@esbuild/linux-mips64el@0.27.3': optional: true - '@esbuild/linux-ppc64@0.27.2': + '@esbuild/linux-ppc64@0.27.3': optional: true - '@esbuild/linux-riscv64@0.27.2': + '@esbuild/linux-riscv64@0.27.3': optional: true - '@esbuild/linux-s390x@0.27.2': + '@esbuild/linux-s390x@0.27.3': optional: true - '@esbuild/linux-x64@0.27.2': + '@esbuild/linux-x64@0.27.3': optional: true - '@esbuild/netbsd-arm64@0.27.2': + '@esbuild/netbsd-arm64@0.27.3': optional: true - '@esbuild/netbsd-x64@0.27.2': + '@esbuild/netbsd-x64@0.27.3': optional: true - '@esbuild/openbsd-arm64@0.27.2': + '@esbuild/openbsd-arm64@0.27.3': optional: true - '@esbuild/openbsd-x64@0.27.2': + '@esbuild/openbsd-x64@0.27.3': optional: true - '@esbuild/openharmony-arm64@0.27.2': + '@esbuild/openharmony-arm64@0.27.3': optional: true - '@esbuild/sunos-x64@0.27.2': + '@esbuild/sunos-x64@0.27.3': optional: true - '@esbuild/win32-arm64@0.27.2': + '@esbuild/win32-arm64@0.27.3': optional: true - '@esbuild/win32-ia32@0.27.2': + '@esbuild/win32-ia32@0.27.3': optional: true - 
'@esbuild/win32-x64@0.27.2': + '@esbuild/win32-x64@0.27.3': optional: true - '@eslint-community/eslint-utils@4.9.0(eslint@9.39.2)': + '@eslint-community/eslint-utils@4.9.1(eslint@10.0.3)': dependencies: - eslint: 9.39.2 + eslint: 10.0.3 eslint-visitor-keys: 3.4.3 '@eslint-community/regexpp@4.12.2': {} - '@eslint/config-array@0.21.1': + '@eslint/config-array@0.23.3': dependencies: - '@eslint/object-schema': 2.1.7 + '@eslint/object-schema': 3.0.3 debug: 4.4.3 - minimatch: 3.1.2 + minimatch: 10.2.4 transitivePeerDependencies: - supports-color - '@eslint/config-helpers@0.4.2': + '@eslint/config-helpers@0.5.3': dependencies: - '@eslint/core': 0.17.0 + '@eslint/core': 1.1.1 - '@eslint/core@0.17.0': + '@eslint/core@1.1.1': dependencies: '@types/json-schema': 7.0.15 - '@eslint/eslintrc@3.3.3': + '@eslint/js@10.0.1(eslint@10.0.3)': + optionalDependencies: + eslint: 10.0.3 + + '@eslint/object-schema@3.0.3': {} + + '@eslint/plugin-kit@0.6.1': dependencies: - ajv: 6.12.6 - debug: 4.4.3 - espree: 10.4.0 - globals: 14.0.0 - ignore: 5.3.2 - import-fresh: 3.3.1 - js-yaml: 4.1.1 - minimatch: 3.1.2 - strip-json-comments: 3.1.1 - transitivePeerDependencies: - - supports-color - - '@eslint/js@9.39.2': {} - - '@eslint/object-schema@2.1.7': {} - - '@eslint/plugin-kit@0.4.1': - dependencies: - '@eslint/core': 0.17.0 + '@eslint/core': 1.1.1 levn: 0.4.1 '@humanfs/core@0.19.1': {} @@ -1416,70 +1379,79 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.5 - '@rollup/rollup-android-arm-eabi@4.54.0': + '@rollup/rollup-android-arm-eabi@4.59.0': optional: true - '@rollup/rollup-android-arm64@4.54.0': + '@rollup/rollup-android-arm64@4.59.0': optional: true - '@rollup/rollup-darwin-arm64@4.54.0': + '@rollup/rollup-darwin-arm64@4.59.0': optional: true - '@rollup/rollup-darwin-x64@4.54.0': + '@rollup/rollup-darwin-x64@4.59.0': optional: true - '@rollup/rollup-freebsd-arm64@4.54.0': + '@rollup/rollup-freebsd-arm64@4.59.0': optional: true - 
'@rollup/rollup-freebsd-x64@4.54.0': + '@rollup/rollup-freebsd-x64@4.59.0': optional: true - '@rollup/rollup-linux-arm-gnueabihf@4.54.0': + '@rollup/rollup-linux-arm-gnueabihf@4.59.0': optional: true - '@rollup/rollup-linux-arm-musleabihf@4.54.0': + '@rollup/rollup-linux-arm-musleabihf@4.59.0': optional: true - '@rollup/rollup-linux-arm64-gnu@4.54.0': + '@rollup/rollup-linux-arm64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-arm64-musl@4.54.0': + '@rollup/rollup-linux-arm64-musl@4.59.0': optional: true - '@rollup/rollup-linux-loong64-gnu@4.54.0': + '@rollup/rollup-linux-loong64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-ppc64-gnu@4.54.0': + '@rollup/rollup-linux-loong64-musl@4.59.0': optional: true - '@rollup/rollup-linux-riscv64-gnu@4.54.0': + '@rollup/rollup-linux-ppc64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-riscv64-musl@4.54.0': + '@rollup/rollup-linux-ppc64-musl@4.59.0': optional: true - '@rollup/rollup-linux-s390x-gnu@4.54.0': + '@rollup/rollup-linux-riscv64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-x64-gnu@4.54.0': + '@rollup/rollup-linux-riscv64-musl@4.59.0': optional: true - '@rollup/rollup-linux-x64-musl@4.54.0': + '@rollup/rollup-linux-s390x-gnu@4.59.0': optional: true - '@rollup/rollup-openharmony-arm64@4.54.0': + '@rollup/rollup-linux-x64-gnu@4.59.0': optional: true - '@rollup/rollup-win32-arm64-msvc@4.54.0': + '@rollup/rollup-linux-x64-musl@4.59.0': optional: true - '@rollup/rollup-win32-ia32-msvc@4.54.0': + '@rollup/rollup-openbsd-x64@4.59.0': optional: true - '@rollup/rollup-win32-x64-gnu@4.54.0': + '@rollup/rollup-openharmony-arm64@4.59.0': optional: true - '@rollup/rollup-win32-x64-msvc@4.54.0': + '@rollup/rollup-win32-arm64-msvc@4.59.0': + optional: true + + '@rollup/rollup-win32-ia32-msvc@4.59.0': + optional: true + + '@rollup/rollup-win32-x64-gnu@4.59.0': + optional: true + + '@rollup/rollup-win32-x64-msvc@4.59.0': optional: true '@standard-schema/spec@1.1.0': {} @@ -1491,193 +1463,186 @@ snapshots: 
'@types/deep-eql@4.0.2': {} + '@types/esrecurse@4.3.1': {} + '@types/estree@1.0.8': {} '@types/json-schema@7.0.15': {} - '@types/node@25.0.3': + '@types/node@25.4.0': dependencies: - undici-types: 7.16.0 + undici-types: 7.18.2 - '@typescript-eslint/eslint-plugin@8.50.1(@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3))(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/eslint-plugin@8.57.0(@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3))(eslint@10.0.3)(typescript@5.9.3)': dependencies: '@eslint-community/regexpp': 4.12.2 - '@typescript-eslint/parser': 8.50.1(eslint@9.39.2)(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.50.1 - '@typescript-eslint/type-utils': 8.50.1(eslint@9.39.2)(typescript@5.9.3) - '@typescript-eslint/utils': 8.50.1(eslint@9.39.2)(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.50.1 - eslint: 9.39.2 + '@typescript-eslint/parser': 8.57.0(eslint@10.0.3)(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.0 + '@typescript-eslint/type-utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.0 + eslint: 10.0.3 ignore: 7.0.5 natural-compare: 1.4.0 - ts-api-utils: 2.1.0(typescript@5.9.3) + ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3)': dependencies: - '@typescript-eslint/scope-manager': 8.50.1 - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/typescript-estree': 8.50.1(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.50.1 + '@typescript-eslint/scope-manager': 8.57.0 + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.0 debug: 4.4.3 - eslint: 9.39.2 + eslint: 10.0.3 typescript: 5.9.3 
transitivePeerDependencies: - supports-color - '@typescript-eslint/project-service@8.50.1(typescript@5.9.3)': + '@typescript-eslint/project-service@8.57.0(typescript@5.9.3)': dependencies: - '@typescript-eslint/tsconfig-utils': 8.50.1(typescript@5.9.3) - '@typescript-eslint/types': 8.50.1 + '@typescript-eslint/tsconfig-utils': 8.57.0(typescript@5.9.3) + '@typescript-eslint/types': 8.57.0 debug: 4.4.3 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/scope-manager@8.50.1': + '@typescript-eslint/scope-manager@8.57.0': dependencies: - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/visitor-keys': 8.50.1 + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/visitor-keys': 8.57.0 - '@typescript-eslint/tsconfig-utils@8.50.1(typescript@5.9.3)': + '@typescript-eslint/tsconfig-utils@8.57.0(typescript@5.9.3)': dependencies: typescript: 5.9.3 - '@typescript-eslint/type-utils@8.50.1(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/type-utils@8.57.0(eslint@10.0.3)(typescript@5.9.3)': dependencies: - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/typescript-estree': 8.50.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.50.1(eslint@9.39.2)(typescript@5.9.3) + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) debug: 4.4.3 - eslint: 9.39.2 - ts-api-utils: 2.1.0(typescript@5.9.3) + eslint: 10.0.3 + ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/types@8.50.1': {} + '@typescript-eslint/types@8.57.0': {} - '@typescript-eslint/typescript-estree@8.50.1(typescript@5.9.3)': + '@typescript-eslint/typescript-estree@8.57.0(typescript@5.9.3)': dependencies: - '@typescript-eslint/project-service': 8.50.1(typescript@5.9.3) - '@typescript-eslint/tsconfig-utils': 8.50.1(typescript@5.9.3) - '@typescript-eslint/types': 8.50.1 - 
'@typescript-eslint/visitor-keys': 8.50.1 + '@typescript-eslint/project-service': 8.57.0(typescript@5.9.3) + '@typescript-eslint/tsconfig-utils': 8.57.0(typescript@5.9.3) + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/visitor-keys': 8.57.0 debug: 4.4.3 - minimatch: 9.0.5 - semver: 7.7.3 + minimatch: 10.2.4 + semver: 7.7.4 tinyglobby: 0.2.15 - ts-api-utils: 2.1.0(typescript@5.9.3) + ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/utils@8.50.1(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/utils@8.57.0(eslint@10.0.3)(typescript@5.9.3)': dependencies: - '@eslint-community/eslint-utils': 4.9.0(eslint@9.39.2) - '@typescript-eslint/scope-manager': 8.50.1 - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/typescript-estree': 8.50.1(typescript@5.9.3) - eslint: 9.39.2 + '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3) + '@typescript-eslint/scope-manager': 8.57.0 + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) + eslint: 10.0.3 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/visitor-keys@8.50.1': + '@typescript-eslint/visitor-keys@8.57.0': dependencies: - '@typescript-eslint/types': 8.50.1 - eslint-visitor-keys: 4.2.1 + '@typescript-eslint/types': 8.57.0 + eslint-visitor-keys: 5.0.1 - '@vitest/coverage-v8@4.0.16(vitest@4.0.16(@types/node@25.0.3))': + '@vitest/coverage-v8@4.0.18(vitest@4.0.18(@types/node@25.4.0))': dependencies: '@bcoe/v8-coverage': 1.0.2 - '@vitest/utils': 4.0.16 - ast-v8-to-istanbul: 0.3.10 + '@vitest/utils': 4.0.18 + ast-v8-to-istanbul: 0.3.12 istanbul-lib-coverage: 3.2.2 istanbul-lib-report: 3.0.1 - istanbul-lib-source-maps: 5.0.6 istanbul-reports: 3.2.0 - magicast: 0.5.1 + magicast: 0.5.2 obug: 2.1.1 std-env: 3.10.0 tinyrainbow: 3.0.3 - vitest: 4.0.16(@types/node@25.0.3) - transitivePeerDependencies: - - supports-color + vitest: 4.0.18(@types/node@25.4.0) 
- '@vitest/expect@4.0.16': + '@vitest/expect@4.0.18': dependencies: '@standard-schema/spec': 1.1.0 '@types/chai': 5.2.3 - '@vitest/spy': 4.0.16 - '@vitest/utils': 4.0.16 + '@vitest/spy': 4.0.18 + '@vitest/utils': 4.0.18 chai: 6.2.2 tinyrainbow: 3.0.3 - '@vitest/mocker@4.0.16(vite@7.3.0(@types/node@25.0.3))': + '@vitest/mocker@4.0.18(vite@7.3.1(@types/node@25.4.0))': dependencies: - '@vitest/spy': 4.0.16 + '@vitest/spy': 4.0.18 estree-walker: 3.0.3 magic-string: 0.30.21 optionalDependencies: - vite: 7.3.0(@types/node@25.0.3) + vite: 7.3.1(@types/node@25.4.0) - '@vitest/pretty-format@4.0.16': + '@vitest/pretty-format@4.0.18': dependencies: tinyrainbow: 3.0.3 - '@vitest/runner@4.0.16': + '@vitest/runner@4.0.18': dependencies: - '@vitest/utils': 4.0.16 + '@vitest/utils': 4.0.18 pathe: 2.0.3 - '@vitest/snapshot@4.0.16': + '@vitest/snapshot@4.0.18': dependencies: - '@vitest/pretty-format': 4.0.16 + '@vitest/pretty-format': 4.0.18 magic-string: 0.30.21 pathe: 2.0.3 - '@vitest/spy@4.0.16': {} + '@vitest/spy@4.0.18': {} - '@vitest/utils@4.0.16': + '@vitest/utils@4.0.18': dependencies: - '@vitest/pretty-format': 4.0.16 + '@vitest/pretty-format': 4.0.18 tinyrainbow: 3.0.3 - acorn-jsx@5.3.2(acorn@8.15.0): + acorn-jsx@5.3.2(acorn@8.16.0): dependencies: - acorn: 8.15.0 + acorn: 8.16.0 - acorn@8.15.0: {} + acorn@8.16.0: {} - ajv@6.12.6: + ajv@6.14.0: dependencies: fast-deep-equal: 3.1.3 fast-json-stable-stringify: 2.1.0 json-schema-traverse: 0.4.1 uri-js: 4.4.1 - ansi-styles@4.3.0: - dependencies: - color-convert: 2.0.1 - any-promise@1.3.0: {} - argparse@2.0.1: {} - assertion-error@2.0.1: {} - ast-v8-to-istanbul@0.3.10: + ast-v8-to-istanbul@0.3.12: dependencies: '@jridgewell/trace-mapping': 0.3.31 estree-walker: 3.0.3 - js-tokens: 9.0.1 + js-tokens: 10.0.0 asynckit@0.4.0: {} - axios@1.13.5: + axios@1.13.6: dependencies: follow-redirects: 1.15.11 form-data: 4.0.5 @@ -1685,20 +1650,15 @@ snapshots: transitivePeerDependencies: - debug - balanced-match@1.0.2: {} + 
balanced-match@4.0.4: {} - brace-expansion@1.1.12: + brace-expansion@5.0.4: dependencies: - balanced-match: 1.0.2 - concat-map: 0.0.1 + balanced-match: 4.0.4 - brace-expansion@2.0.2: + bundle-require@5.1.0(esbuild@0.27.3): dependencies: - balanced-match: 1.0.2 - - bundle-require@5.1.0(esbuild@0.27.2): - dependencies: - esbuild: 0.27.2 + esbuild: 0.27.3 load-tsconfig: 0.2.5 cac@6.7.14: {} @@ -1708,33 +1668,18 @@ snapshots: es-errors: 1.3.0 function-bind: 1.1.2 - callsites@3.1.0: {} - chai@6.2.2: {} - chalk@4.1.2: - dependencies: - ansi-styles: 4.3.0 - supports-color: 7.2.0 - chokidar@4.0.3: dependencies: readdirp: 4.1.2 - color-convert@2.0.1: - dependencies: - color-name: 1.1.4 - - color-name@1.1.4: {} - combined-stream@1.0.8: dependencies: delayed-stream: 1.0.0 commander@4.1.1: {} - concat-map@0.0.1: {} - confbox@0.1.8: {} consola@3.4.2: {} @@ -1776,69 +1721,68 @@ snapshots: has-tostringtag: 1.0.2 hasown: 2.0.2 - esbuild@0.27.2: + esbuild@0.27.3: optionalDependencies: - '@esbuild/aix-ppc64': 0.27.2 - '@esbuild/android-arm': 0.27.2 - '@esbuild/android-arm64': 0.27.2 - '@esbuild/android-x64': 0.27.2 - '@esbuild/darwin-arm64': 0.27.2 - '@esbuild/darwin-x64': 0.27.2 - '@esbuild/freebsd-arm64': 0.27.2 - '@esbuild/freebsd-x64': 0.27.2 - '@esbuild/linux-arm': 0.27.2 - '@esbuild/linux-arm64': 0.27.2 - '@esbuild/linux-ia32': 0.27.2 - '@esbuild/linux-loong64': 0.27.2 - '@esbuild/linux-mips64el': 0.27.2 - '@esbuild/linux-ppc64': 0.27.2 - '@esbuild/linux-riscv64': 0.27.2 - '@esbuild/linux-s390x': 0.27.2 - '@esbuild/linux-x64': 0.27.2 - '@esbuild/netbsd-arm64': 0.27.2 - '@esbuild/netbsd-x64': 0.27.2 - '@esbuild/openbsd-arm64': 0.27.2 - '@esbuild/openbsd-x64': 0.27.2 - '@esbuild/openharmony-arm64': 0.27.2 - '@esbuild/sunos-x64': 0.27.2 - '@esbuild/win32-arm64': 0.27.2 - '@esbuild/win32-ia32': 0.27.2 - '@esbuild/win32-x64': 0.27.2 + '@esbuild/aix-ppc64': 0.27.3 + '@esbuild/android-arm': 0.27.3 + '@esbuild/android-arm64': 0.27.3 + '@esbuild/android-x64': 0.27.3 + 
'@esbuild/darwin-arm64': 0.27.3 + '@esbuild/darwin-x64': 0.27.3 + '@esbuild/freebsd-arm64': 0.27.3 + '@esbuild/freebsd-x64': 0.27.3 + '@esbuild/linux-arm': 0.27.3 + '@esbuild/linux-arm64': 0.27.3 + '@esbuild/linux-ia32': 0.27.3 + '@esbuild/linux-loong64': 0.27.3 + '@esbuild/linux-mips64el': 0.27.3 + '@esbuild/linux-ppc64': 0.27.3 + '@esbuild/linux-riscv64': 0.27.3 + '@esbuild/linux-s390x': 0.27.3 + '@esbuild/linux-x64': 0.27.3 + '@esbuild/netbsd-arm64': 0.27.3 + '@esbuild/netbsd-x64': 0.27.3 + '@esbuild/openbsd-arm64': 0.27.3 + '@esbuild/openbsd-x64': 0.27.3 + '@esbuild/openharmony-arm64': 0.27.3 + '@esbuild/sunos-x64': 0.27.3 + '@esbuild/win32-arm64': 0.27.3 + '@esbuild/win32-ia32': 0.27.3 + '@esbuild/win32-x64': 0.27.3 escape-string-regexp@4.0.0: {} - eslint-scope@8.4.0: + eslint-scope@9.1.2: dependencies: + '@types/esrecurse': 4.3.1 + '@types/estree': 1.0.8 esrecurse: 4.3.0 estraverse: 5.3.0 eslint-visitor-keys@3.4.3: {} - eslint-visitor-keys@4.2.1: {} + eslint-visitor-keys@5.0.1: {} - eslint@9.39.2: + eslint@10.0.3: dependencies: - '@eslint-community/eslint-utils': 4.9.0(eslint@9.39.2) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3) '@eslint-community/regexpp': 4.12.2 - '@eslint/config-array': 0.21.1 - '@eslint/config-helpers': 0.4.2 - '@eslint/core': 0.17.0 - '@eslint/eslintrc': 3.3.3 - '@eslint/js': 9.39.2 - '@eslint/plugin-kit': 0.4.1 + '@eslint/config-array': 0.23.3 + '@eslint/config-helpers': 0.5.3 + '@eslint/core': 1.1.1 + '@eslint/plugin-kit': 0.6.1 '@humanfs/node': 0.16.7 '@humanwhocodes/module-importer': 1.0.1 '@humanwhocodes/retry': 0.4.3 '@types/estree': 1.0.8 - ajv: 6.12.6 - chalk: 4.1.2 + ajv: 6.14.0 cross-spawn: 7.0.6 debug: 4.4.3 escape-string-regexp: 4.0.0 - eslint-scope: 8.4.0 - eslint-visitor-keys: 4.2.1 - espree: 10.4.0 - esquery: 1.6.0 + eslint-scope: 9.1.2 + eslint-visitor-keys: 5.0.1 + espree: 11.2.0 + esquery: 1.7.0 esutils: 2.0.3 fast-deep-equal: 3.1.3 file-entry-cache: 8.0.0 @@ -1848,20 +1792,19 @@ snapshots: imurmurhash: 0.1.4 
is-glob: 4.0.3 json-stable-stringify-without-jsonify: 1.0.1 - lodash.merge: 4.6.2 - minimatch: 3.1.2 + minimatch: 10.2.4 natural-compare: 1.4.0 optionator: 0.9.4 transitivePeerDependencies: - supports-color - espree@10.4.0: + espree@11.2.0: dependencies: - acorn: 8.15.0 - acorn-jsx: 5.3.2(acorn@8.15.0) - eslint-visitor-keys: 4.2.1 + acorn: 8.16.0 + acorn-jsx: 5.3.2(acorn@8.16.0) + eslint-visitor-keys: 5.0.1 - esquery@1.6.0: + esquery@1.7.0: dependencies: estraverse: 5.3.0 @@ -1901,15 +1844,15 @@ snapshots: fix-dts-default-cjs-exports@1.0.1: dependencies: magic-string: 0.30.21 - mlly: 1.8.0 - rollup: 4.54.0 + mlly: 1.8.1 + rollup: 4.59.0 flat-cache@4.0.1: dependencies: - flatted: 3.3.3 + flatted: 3.4.1 keyv: 4.5.4 - flatted@3.3.3: {} + flatted@3.4.1: {} follow-redirects@1.15.11: {} @@ -1948,8 +1891,6 @@ snapshots: dependencies: is-glob: 4.0.3 - globals@14.0.0: {} - gopd@1.2.0: {} has-flag@4.0.0: {} @@ -1970,11 +1911,6 @@ snapshots: ignore@7.0.5: {} - import-fresh@3.3.1: - dependencies: - parent-module: 1.0.1 - resolve-from: 4.0.0 - imurmurhash@0.1.4: {} is-extglob@2.1.1: {} @@ -1993,14 +1929,6 @@ snapshots: make-dir: 4.0.0 supports-color: 7.2.0 - istanbul-lib-source-maps@5.0.6: - dependencies: - '@jridgewell/trace-mapping': 0.3.31 - debug: 4.4.3 - istanbul-lib-coverage: 3.2.2 - transitivePeerDependencies: - - supports-color - istanbul-reports@3.2.0: dependencies: html-escaper: 2.0.2 @@ -2008,11 +1936,7 @@ snapshots: joycon@3.1.1: {} - js-tokens@9.0.1: {} - - js-yaml@4.1.1: - dependencies: - argparse: 2.0.1 + js-tokens@10.0.0: {} json-buffer@3.0.1: {} @@ -2039,21 +1963,19 @@ snapshots: dependencies: p-locate: 5.0.0 - lodash.merge@4.6.2: {} - magic-string@0.30.21: dependencies: '@jridgewell/sourcemap-codec': 1.5.5 - magicast@0.5.1: + magicast@0.5.2: dependencies: - '@babel/parser': 7.28.5 - '@babel/types': 7.28.5 + '@babel/parser': 7.29.0 + '@babel/types': 7.29.0 source-map-js: 1.2.1 make-dir@4.0.0: dependencies: - semver: 7.7.3 + semver: 7.7.4 math-intrinsics@1.1.0: 
{} @@ -2063,20 +1985,16 @@ snapshots: dependencies: mime-db: 1.52.0 - minimatch@3.1.2: + minimatch@10.2.4: dependencies: - brace-expansion: 1.1.12 + brace-expansion: 5.0.4 - minimatch@9.0.5: + mlly@1.8.1: dependencies: - brace-expansion: 2.0.2 - - mlly@1.8.0: - dependencies: - acorn: 8.15.0 + acorn: 8.16.0 pathe: 2.0.3 pkg-types: 1.3.1 - ufo: 1.6.1 + ufo: 1.6.3 ms@2.1.3: {} @@ -2111,10 +2029,6 @@ snapshots: dependencies: p-limit: 3.1.0 - parent-module@1.0.1: - dependencies: - callsites: 3.1.0 - path-exists@4.0.0: {} path-key@3.1.1: {} @@ -2130,16 +2044,16 @@ snapshots: pkg-types@1.3.1: dependencies: confbox: 0.1.8 - mlly: 1.8.0 + mlly: 1.8.1 pathe: 2.0.3 - postcss-load-config@6.0.1(postcss@8.5.6): + postcss-load-config@6.0.1(postcss@8.5.8): dependencies: lilconfig: 3.1.3 optionalDependencies: - postcss: 8.5.6 + postcss: 8.5.8 - postcss@8.5.6: + postcss@8.5.8: dependencies: nanoid: 3.3.11 picocolors: 1.1.1 @@ -2153,39 +2067,40 @@ snapshots: readdirp@4.1.2: {} - resolve-from@4.0.0: {} - resolve-from@5.0.0: {} - rollup@4.54.0: + rollup@4.59.0: dependencies: '@types/estree': 1.0.8 optionalDependencies: - '@rollup/rollup-android-arm-eabi': 4.54.0 - '@rollup/rollup-android-arm64': 4.54.0 - '@rollup/rollup-darwin-arm64': 4.54.0 - '@rollup/rollup-darwin-x64': 4.54.0 - '@rollup/rollup-freebsd-arm64': 4.54.0 - '@rollup/rollup-freebsd-x64': 4.54.0 - '@rollup/rollup-linux-arm-gnueabihf': 4.54.0 - '@rollup/rollup-linux-arm-musleabihf': 4.54.0 - '@rollup/rollup-linux-arm64-gnu': 4.54.0 - '@rollup/rollup-linux-arm64-musl': 4.54.0 - '@rollup/rollup-linux-loong64-gnu': 4.54.0 - '@rollup/rollup-linux-ppc64-gnu': 4.54.0 - '@rollup/rollup-linux-riscv64-gnu': 4.54.0 - '@rollup/rollup-linux-riscv64-musl': 4.54.0 - '@rollup/rollup-linux-s390x-gnu': 4.54.0 - '@rollup/rollup-linux-x64-gnu': 4.54.0 - '@rollup/rollup-linux-x64-musl': 4.54.0 - '@rollup/rollup-openharmony-arm64': 4.54.0 - '@rollup/rollup-win32-arm64-msvc': 4.54.0 - '@rollup/rollup-win32-ia32-msvc': 4.54.0 - 
'@rollup/rollup-win32-x64-gnu': 4.54.0 - '@rollup/rollup-win32-x64-msvc': 4.54.0 + '@rollup/rollup-android-arm-eabi': 4.59.0 + '@rollup/rollup-android-arm64': 4.59.0 + '@rollup/rollup-darwin-arm64': 4.59.0 + '@rollup/rollup-darwin-x64': 4.59.0 + '@rollup/rollup-freebsd-arm64': 4.59.0 + '@rollup/rollup-freebsd-x64': 4.59.0 + '@rollup/rollup-linux-arm-gnueabihf': 4.59.0 + '@rollup/rollup-linux-arm-musleabihf': 4.59.0 + '@rollup/rollup-linux-arm64-gnu': 4.59.0 + '@rollup/rollup-linux-arm64-musl': 4.59.0 + '@rollup/rollup-linux-loong64-gnu': 4.59.0 + '@rollup/rollup-linux-loong64-musl': 4.59.0 + '@rollup/rollup-linux-ppc64-gnu': 4.59.0 + '@rollup/rollup-linux-ppc64-musl': 4.59.0 + '@rollup/rollup-linux-riscv64-gnu': 4.59.0 + '@rollup/rollup-linux-riscv64-musl': 4.59.0 + '@rollup/rollup-linux-s390x-gnu': 4.59.0 + '@rollup/rollup-linux-x64-gnu': 4.59.0 + '@rollup/rollup-linux-x64-musl': 4.59.0 + '@rollup/rollup-openbsd-x64': 4.59.0 + '@rollup/rollup-openharmony-arm64': 4.59.0 + '@rollup/rollup-win32-arm64-msvc': 4.59.0 + '@rollup/rollup-win32-ia32-msvc': 4.59.0 + '@rollup/rollup-win32-x64-gnu': 4.59.0 + '@rollup/rollup-win32-x64-msvc': 4.59.0 fsevents: 2.3.3 - semver@7.7.3: {} + semver@7.7.4: {} shebang-command@2.0.0: dependencies: @@ -2203,8 +2118,6 @@ snapshots: std-env@3.10.0: {} - strip-json-comments@3.1.1: {} - sucrase@3.35.1: dependencies: '@jridgewell/gen-mapping': 0.3.13 @@ -2242,33 +2155,33 @@ snapshots: tree-kill@1.2.2: {} - ts-api-utils@2.1.0(typescript@5.9.3): + ts-api-utils@2.4.0(typescript@5.9.3): dependencies: typescript: 5.9.3 ts-interface-checker@0.1.13: {} - tsup@8.5.1(postcss@8.5.6)(typescript@5.9.3): + tsup@8.5.1(postcss@8.5.8)(typescript@5.9.3): dependencies: - bundle-require: 5.1.0(esbuild@0.27.2) + bundle-require: 5.1.0(esbuild@0.27.3) cac: 6.7.14 chokidar: 4.0.3 consola: 3.4.2 debug: 4.4.3 - esbuild: 0.27.2 + esbuild: 0.27.3 fix-dts-default-cjs-exports: 1.0.1 joycon: 3.1.1 picocolors: 1.1.1 - postcss-load-config: 6.0.1(postcss@8.5.6) + 
postcss-load-config: 6.0.1(postcss@8.5.8) resolve-from: 5.0.0 - rollup: 4.54.0 + rollup: 4.59.0 source-map: 0.7.6 sucrase: 3.35.1 tinyexec: 0.3.2 tinyglobby: 0.2.15 tree-kill: 1.2.2 optionalDependencies: - postcss: 8.5.6 + postcss: 8.5.8 typescript: 5.9.3 transitivePeerDependencies: - jiti @@ -2282,35 +2195,35 @@ snapshots: typescript@5.9.3: {} - ufo@1.6.1: {} + ufo@1.6.3: {} - undici-types@7.16.0: {} + undici-types@7.18.2: {} uri-js@4.4.1: dependencies: punycode: 2.3.1 - vite@7.3.0(@types/node@25.0.3): + vite@7.3.1(@types/node@25.4.0): dependencies: - esbuild: 0.27.2 + esbuild: 0.27.3 fdir: 6.5.0(picomatch@4.0.3) picomatch: 4.0.3 - postcss: 8.5.6 - rollup: 4.54.0 + postcss: 8.5.8 + rollup: 4.59.0 tinyglobby: 0.2.15 optionalDependencies: - '@types/node': 25.0.3 + '@types/node': 25.4.0 fsevents: 2.3.3 - vitest@4.0.16(@types/node@25.0.3): + vitest@4.0.18(@types/node@25.4.0): dependencies: - '@vitest/expect': 4.0.16 - '@vitest/mocker': 4.0.16(vite@7.3.0(@types/node@25.0.3)) - '@vitest/pretty-format': 4.0.16 - '@vitest/runner': 4.0.16 - '@vitest/snapshot': 4.0.16 - '@vitest/spy': 4.0.16 - '@vitest/utils': 4.0.16 + '@vitest/expect': 4.0.18 + '@vitest/mocker': 4.0.18(vite@7.3.1(@types/node@25.4.0)) + '@vitest/pretty-format': 4.0.18 + '@vitest/runner': 4.0.18 + '@vitest/snapshot': 4.0.18 + '@vitest/spy': 4.0.18 + '@vitest/utils': 4.0.18 es-module-lexer: 1.7.0 expect-type: 1.3.0 magic-string: 0.30.21 @@ -2322,10 +2235,10 @@ snapshots: tinyexec: 1.0.2 tinyglobby: 0.2.15 tinyrainbow: 3.0.3 - vite: 7.3.0(@types/node@25.0.3) + vite: 7.3.1(@types/node@25.4.0) why-is-node-running: 2.3.0 optionalDependencies: - '@types/node': 25.0.3 + '@types/node': 25.4.0 transitivePeerDependencies: - jiti - less diff --git a/web/AGENTS.md b/web/AGENTS.md index 5dd41b8a3c..71000eafdb 100644 --- a/web/AGENTS.md +++ b/web/AGENTS.md @@ -2,6 +2,12 @@ - Refer to the `./docs/test.md` and `./docs/lint.md` for detailed frontend workflow instructions. 
+## Overlay Components (Mandatory) + +- `./docs/overlay-migration.md` is the source of truth for overlay-related work. +- In new or modified code, use only overlay primitives from `@/app/components/base/ui/*`. +- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding). + ## Automated Test Generation - Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests. diff --git a/web/Dockerfile b/web/Dockerfile index 9b24f9ea0a..b54bae706c 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -35,7 +35,7 @@ COPY --from=packages /app/web/ . COPY . . ENV NODE_OPTIONS="--max-old-space-size=4096" -RUN pnpm build:docker +RUN pnpm build # production stage diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts index 221ba2808f..71f5b009d3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts @@ -5,7 +5,7 @@ export const docURL = { [TracingProvider.phoenix]: 'https://docs.arize.com/phoenix', [TracingProvider.langSmith]: 'https://docs.smith.langchain.com/', [TracingProvider.langfuse]: 'https://docs.langfuse.com', - [TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions', + [TracingProvider.opik]: 'https://www.comet.com/docs/opik/integrations/dify', [TracingProvider.weave]: 'https://weave-docs.wandb.ai/', [TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680', [TracingProvider.mlflow]: 'https://mlflow.org/docs/latest/genai/', diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 
0a17822187..9bd32d2576 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -11,7 +11,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import ImageInput from '@/app/components/base/app-icon-picker/ImageInput' import getCroppedImg from '@/app/components/base/app-icon-picker/utils' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' @@ -103,7 +103,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { <>
- setOnAvatarError(x)} /> + setOnAvatarError(status === 'error')} />
{ diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index 908ef9c2e8..58331e3a77 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -4,6 +4,7 @@ import type { App } from '@/types/app' import { RiGraduationCapFill, } from '@remixicon/react' +import { useQueryClient } from '@tanstack/react-query' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -15,11 +16,11 @@ import PremiumBadge from '@/app/components/base/premium-badge' import { ToastContext } from '@/app/components/base/toast/context' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' -import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { updateUserProfile } from '@/service/common' import { useAppList } from '@/service/use-apps' +import { commonQueryKeys, useUserProfile } from '@/service/use-common' import DeleteAccount from '../delete-account' import AvatarWithEdit from './AvatarWithEdit' @@ -37,7 +38,10 @@ export default function AccountPage() { const { systemFeatures } = useGlobalPublicStore() const { data: appList } = useAppList({ page: 1, limit: 100, name: '' }) const apps = appList?.data || [] - const { mutateUserProfile, userProfile } = useAppContext() + const queryClient = useQueryClient() + const { data: userProfileResp } = useUserProfile() + const userProfile = userProfileResp?.profile + const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: commonQueryKeys.userProfile }) const { isEducationAccount } = useProviderContext() const { notify } = useContext(ToastContext) const [editNameModalVisible, setEditNameModalVisible] = useState(false) 
@@ -53,6 +57,9 @@ export default function AccountPage() { const [showConfirmPassword, setShowConfirmPassword] = useState(false) const [showUpdateEmail, setShowUpdateEmail] = useState(false) + if (!userProfile) + return null + const handleEditName = () => { setEditNameModalVisible(true) setEditName(userProfile.name) @@ -149,7 +156,7 @@ export default function AccountPage() {

{t('account.myAccount', { ns: 'common' })}

- +

{userProfile.name} diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 8ea29e8e45..07b685b8c5 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -7,12 +7,11 @@ import { useRouter } from 'next/navigation' import { Fragment } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' -import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' -import { useLogout } from '@/service/use-common' +import { useLogout, useUserProfile } from '@/service/use-common' export type IAppSelector = { isMobile: boolean @@ -21,10 +20,15 @@ export type IAppSelector = { export default function AppSelector() { const router = useRouter() const { t } = useTranslation() - const { userProfile } = useAppContext() + const { data: userProfileResp } = useUserProfile() + const userProfile = userProfileResp?.profile const { isEducationAccount } = useProviderContext() const { mutateAsync: logout } = useLogout() + + if (!userProfile) + return null + const handleLogout = async () => { await logout() @@ -50,7 +54,7 @@ export default function AppSelector() { ${open && 'bg-components-panel-bg-blur'} `} > - +

{userProfile.email}
- +
diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index d718e0941d..835a1e702e 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -11,14 +11,13 @@ import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' -import { useAppContext } from '@/context/app-context' -import { useIsLogin } from '@/service/use-common' +import { useIsLogin, useUserProfile } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' function buildReturnUrl(pathname: string, search: string) { @@ -62,7 +61,8 @@ export default function OAuthAuthorize() { const searchParams = useSearchParams() const client_id = decodeURIComponent(searchParams.get('client_id') || '') const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '') - const { userProfile } = useAppContext() + const { data: userProfileResp } = useUserProfile() + const userProfile = userProfileResp?.profile const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() const hasNotifiedRef = useRef(false) @@ -138,7 +138,7 @@ export default function OAuthAuthorize() { {isLoggedIn && userProfile && (
- +
{userProfile.name}
{userProfile.email}
diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index 5c803a91f0..90cbac13a4 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -10,7 +10,7 @@ import { SubjectType } from '@/models/access-control' import { useSearchForWhiteListCandidates } from '@/service/access-control' import { cn } from '@/utils/classnames' import useAccessControlStore from '../../../../context/access-control-store' -import Avatar from '../../base/avatar' +import { Avatar } from '../../base/avatar' import Button from '../../base/button' import Checkbox from '../../base/checkbox' import Input from '../../base/input' @@ -203,7 +203,7 @@ function MemberItem({ member }: MemberItemProps) {
- +

{member.name}

diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx index 8ca817c872..2c0e4b2694 100644 --- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx +++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects } from '@/service/access-control' import useAccessControlStore from '../../../../context/access-control-store' -import Avatar from '../../base/avatar' +import { Avatar } from '../../base/avatar' import Loading from '../../base/loading' import Tooltip from '../../base/tooltip' import AddMemberOrGroupDialog from './add-member-or-group-pop' @@ -106,7 +106,7 @@ function MemberItem({ member }: MemberItemProps) { }, [member, setSpecificMembers, specificMembers]) return ( } + icon={} onRemove={handleRemoveMember} >

{member.name}

diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx index 0bbed83a99..09a5ff6d07 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx @@ -172,12 +172,8 @@ describe('dataset-config/card-item', () => { const [editButton] = within(card).getAllByRole('button', { hidden: true }) await user.click(editButton) - expect(screen.getByText('Mock settings modal')).toBeInTheDocument() - await waitFor(() => { - expect(screen.getByRole('dialog')).toBeVisible() - }) - - fireEvent.click(screen.getByText('Save changes')) + expect(await screen.findByText('Mock settings modal')).toBeInTheDocument() + fireEvent.click(await screen.findByText('Save changes')) await waitFor(() => { expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' })) @@ -194,7 +190,7 @@ describe('dataset-config/card-item', () => { const card = screen.getByText(dataset.name).closest('.group') as HTMLElement const buttons = within(card).getAllByRole('button', { hidden: true }) - const deleteButton = buttons[buttons.length - 1] + const deleteButton = buttons.at(-1)! 
expect(deleteButton.className).not.toContain('action-btn-destructive') @@ -233,7 +229,7 @@ describe('dataset-config/card-item', () => { await user.click(editButton) expect(screen.getByText('Mock settings modal')).toBeInTheDocument() - const overlay = Array.from(document.querySelectorAll('[class]')) + const overlay = [...document.querySelectorAll('[class]')] .find(element => element.className.toString().includes('bg-black/30')) expect(overlay).toBeInTheDocument() diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx index d621bb3941..350ede8c96 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx @@ -91,7 +91,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({ })) vi.mock('@/app/components/base/avatar', () => ({ - default: ({ name }: { name: string }) =>
{name}
, + Avatar: ({ name }: { name: string }) =>
{name}
, })) const createModelAndParameter = (overrides: Partial = {}): ModelAndParameter => ({ diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index b7a7e90fca..e957fc24c4 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -7,7 +7,7 @@ import { useCallback, useMemo, } from 'react' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { getLastAnswer } from '@/app/components/base/chat/utils' @@ -149,7 +149,7 @@ const ChatItem: FC = ({ suggestedQuestions={suggestedQuestions} onSend={doSend} showPromptLog - questionIcon={} + questionIcon={} allToolIcons={allToolIcons} hideLogModal noSpacing diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index addeb92297..84ff8b5ede 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -3,7 +3,7 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty import type { FileEntity } from '@/app/components/base/file-uploader/types' import { memo, useCallback, useImperativeHandle, useMemo } from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from 
'@/app/components/base/chat/utils' @@ -168,7 +168,7 @@ const DebugWithSingleModel = ( switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} onStopResponding={handleStop} showPromptLog - questionIcon={} + questionIcon={} allToolIcons={allToolIcons} onAnnotationEdited={handleAnnotationEdited} onAnnotationAdded={handleAnnotationAdded} diff --git a/web/app/components/base/amplitude/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/AmplitudeProvider.spec.tsx new file mode 100644 index 0000000000..2402c84a3e --- /dev/null +++ b/web/app/components/base/amplitude/AmplitudeProvider.spec.tsx @@ -0,0 +1,139 @@ +import * as amplitude from '@amplitude/analytics-browser' +import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' +import { render } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider' + +const mockConfig = vi.hoisted(() => ({ + AMPLITUDE_API_KEY: 'test-api-key', + IS_CLOUD_EDITION: true, +})) + +vi.mock('@/config', () => mockConfig) + +vi.mock('@amplitude/analytics-browser', () => ({ + init: vi.fn(), + add: vi.fn(), +})) + +vi.mock('@amplitude/plugin-session-replay-browser', () => ({ + sessionReplayPlugin: vi.fn(() => ({ name: 'session-replay' })), +})) + +describe('AmplitudeProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + mockConfig.AMPLITUDE_API_KEY = 'test-api-key' + mockConfig.IS_CLOUD_EDITION = true + }) + + describe('isAmplitudeEnabled', () => { + it('returns true when cloud edition and api key present', () => { + expect(isAmplitudeEnabled()).toBe(true) + }) + + it('returns false when cloud edition but no api key', () => { + mockConfig.AMPLITUDE_API_KEY = '' + expect(isAmplitudeEnabled()).toBe(false) + }) + + it('returns false when not cloud edition', () => { + mockConfig.IS_CLOUD_EDITION = false + expect(isAmplitudeEnabled()).toBe(false) + }) + }) + + describe('Component', () 
=> { + it('initializes amplitude when enabled', () => { + render() + + expect(amplitude.init).toHaveBeenCalledWith('test-api-key', expect.any(Object)) + expect(sessionReplayPlugin).toHaveBeenCalledWith({ sampleRate: 0.8 }) + expect(amplitude.add).toHaveBeenCalledTimes(2) + }) + + it('does not initialize amplitude when disabled', () => { + mockConfig.AMPLITUDE_API_KEY = '' + render() + + expect(amplitude.init).not.toHaveBeenCalled() + expect(amplitude.add).not.toHaveBeenCalled() + }) + + it('pageNameEnrichmentPlugin logic works as expected', async () => { + render() + const plugin = vi.mocked(amplitude.add).mock.calls[0]?.[0] as amplitude.Types.EnrichmentPlugin | undefined + expect(plugin).toBeDefined() + if (!plugin?.execute || !plugin.setup) + throw new Error('Expected page-name-enrichment plugin with setup/execute') + + expect(plugin.name).toBe('page-name-enrichment') + + const execute = plugin.execute + const setup = plugin.setup + type SetupFn = NonNullable + const getPageTitle = (evt: amplitude.Types.Event | null | undefined) => + (evt?.event_properties as Record | undefined)?.['[Amplitude] Page Title'] + + await setup( + {} as Parameters[0], + {} as Parameters[1], + ) + + const originalWindowLocation = window.location + try { + Object.defineProperty(window, 'location', { + value: { pathname: '/datasets' }, + writable: true, + }) + const event: amplitude.Types.Event = { + event_type: '[Amplitude] Page Viewed', + event_properties: {}, + } + const result = await execute(event) + expect(getPageTitle(result)).toBe('Knowledge') + window.location.pathname = '/' + await execute(event) + expect(getPageTitle(event)).toBe('Home') + window.location.pathname = '/apps' + await execute(event) + expect(getPageTitle(event)).toBe('Studio') + window.location.pathname = '/explore' + await execute(event) + expect(getPageTitle(event)).toBe('Explore') + window.location.pathname = '/tools' + await execute(event) + expect(getPageTitle(event)).toBe('Tools') + window.location.pathname 
= '/account' + await execute(event) + expect(getPageTitle(event)).toBe('Account') + window.location.pathname = '/signin' + await execute(event) + expect(getPageTitle(event)).toBe('Sign In') + window.location.pathname = '/signup' + await execute(event) + expect(getPageTitle(event)).toBe('Sign Up') + window.location.pathname = '/unknown' + await execute(event) + expect(getPageTitle(event)).toBe('Unknown') + const otherEvent = { + event_type: 'Button Clicked', + event_properties: {}, + } as amplitude.Types.Event + const otherResult = await execute(otherEvent) + expect(getPageTitle(otherResult)).toBeUndefined() + const noPropsEvent = { + event_type: '[Amplitude] Page Viewed', + } as amplitude.Types.Event + const noPropsResult = await execute(noPropsEvent) + expect(noPropsResult?.event_properties).toBeUndefined() + } + finally { + Object.defineProperty(window, 'location', { + value: originalWindowLocation, + writable: true, + }) + } + }) + }) +}) diff --git a/web/app/components/base/amplitude/index.spec.ts b/web/app/components/base/amplitude/index.spec.ts new file mode 100644 index 0000000000..919c0b68d1 --- /dev/null +++ b/web/app/components/base/amplitude/index.spec.ts @@ -0,0 +1,32 @@ +import { describe, expect, it } from 'vitest' +import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider' +import indexDefault, { + isAmplitudeEnabled as indexIsAmplitudeEnabled, + resetUser, + setUserId, + setUserProperties, + trackEvent, +} from './index' +import { + resetUser as utilsResetUser, + setUserId as utilsSetUserId, + setUserProperties as utilsSetUserProperties, + trackEvent as utilsTrackEvent, +} from './utils' + +describe('Amplitude index exports', () => { + it('exports AmplitudeProvider as default', () => { + expect(indexDefault).toBe(AmplitudeProvider) + }) + + it('exports isAmplitudeEnabled', () => { + expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) + }) + + it('exports utils', () => { + expect(resetUser).toBe(utilsResetUser) + 
expect(setUserId).toBe(utilsSetUserId) + expect(setUserProperties).toBe(utilsSetUserProperties) + expect(trackEvent).toBe(utilsTrackEvent) + }) +}) diff --git a/web/app/components/base/amplitude/utils.spec.ts b/web/app/components/base/amplitude/utils.spec.ts new file mode 100644 index 0000000000..c69fc93aa4 --- /dev/null +++ b/web/app/components/base/amplitude/utils.spec.ts @@ -0,0 +1,119 @@ +import { resetUser, setUserId, setUserProperties, trackEvent } from './utils' + +const mockState = vi.hoisted(() => ({ + enabled: true, +})) + +const mockTrack = vi.hoisted(() => vi.fn()) +const mockSetUserId = vi.hoisted(() => vi.fn()) +const mockIdentify = vi.hoisted(() => vi.fn()) +const mockReset = vi.hoisted(() => vi.fn()) + +const MockIdentify = vi.hoisted(() => + class { + setCalls: Array<[string, unknown]> = [] + + set(key: string, value: unknown) { + this.setCalls.push([key, value]) + return this + } + }, +) + +vi.mock('./AmplitudeProvider', () => ({ + isAmplitudeEnabled: () => mockState.enabled, +})) + +vi.mock('@amplitude/analytics-browser', () => ({ + track: (...args: unknown[]) => mockTrack(...args), + setUserId: (...args: unknown[]) => mockSetUserId(...args), + identify: (...args: unknown[]) => mockIdentify(...args), + reset: (...args: unknown[]) => mockReset(...args), + Identify: MockIdentify, +})) + +describe('amplitude utils', () => { + beforeEach(() => { + vi.clearAllMocks() + mockState.enabled = true + }) + + describe('trackEvent', () => { + it('should call amplitude.track when amplitude is enabled', () => { + trackEvent('dataset_created', { source: 'wizard' }) + + expect(mockTrack).toHaveBeenCalledTimes(1) + expect(mockTrack).toHaveBeenCalledWith('dataset_created', { source: 'wizard' }) + }) + + it('should not call amplitude.track when amplitude is disabled', () => { + mockState.enabled = false + + trackEvent('dataset_created', { source: 'wizard' }) + + expect(mockTrack).not.toHaveBeenCalled() + }) + }) + + describe('setUserId', () => { + it('should call 
amplitude.setUserId when amplitude is enabled', () => { + setUserId('user-123') + + expect(mockSetUserId).toHaveBeenCalledTimes(1) + expect(mockSetUserId).toHaveBeenCalledWith('user-123') + }) + + it('should not call amplitude.setUserId when amplitude is disabled', () => { + mockState.enabled = false + + setUserId('user-123') + + expect(mockSetUserId).not.toHaveBeenCalled() + }) + }) + + describe('setUserProperties', () => { + it('should build identify event and call amplitude.identify when amplitude is enabled', () => { + const properties: Record = { + role: 'owner', + seats: 3, + verified: true, + } + + setUserProperties(properties) + + expect(mockIdentify).toHaveBeenCalledTimes(1) + const identifyArg = mockIdentify.mock.calls[0][0] as InstanceType + expect(identifyArg).toBeInstanceOf(MockIdentify) + expect(identifyArg.setCalls).toEqual([ + ['role', 'owner'], + ['seats', 3], + ['verified', true], + ]) + }) + + it('should not call amplitude.identify when amplitude is disabled', () => { + mockState.enabled = false + + setUserProperties({ role: 'owner' }) + + expect(mockIdentify).not.toHaveBeenCalled() + }) + }) + + describe('resetUser', () => { + it('should call amplitude.reset when amplitude is enabled', () => { + resetUser() + + expect(mockReset).toHaveBeenCalledTimes(1) + }) + + it('should not call amplitude.reset when amplitude is disabled', () => { + mockState.enabled = false + + resetUser() + + expect(mockReset).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/base/audio-btn/__tests__/audio.player.manager.spec.ts b/web/app/components/base/audio-btn/__tests__/audio.player.manager.spec.ts new file mode 100644 index 0000000000..c613aa2c11 --- /dev/null +++ b/web/app/components/base/audio-btn/__tests__/audio.player.manager.spec.ts @@ -0,0 +1,148 @@ +import { AudioPlayerManager } from '../audio.player.manager' + +type AudioCallback = ((event: string) => void) | null +type AudioPlayerCtorArgs = [ + string, + boolean, + string | undefined, + 
string | null | undefined, + string | undefined, + AudioCallback, +] + +type MockAudioPlayerInstance = { + setCallback: ReturnType + pauseAudio: ReturnType + resetMsgId: ReturnType + cacheBuffers: Array + sourceBuffer: { + abort: ReturnType + } | undefined +} + +const mockState = vi.hoisted(() => ({ + instances: [] as MockAudioPlayerInstance[], +})) + +const mockAudioPlayerConstructor = vi.hoisted(() => vi.fn()) + +const MockAudioPlayer = vi.hoisted(() => { + return class MockAudioPlayerClass { + setCallback = vi.fn() + pauseAudio = vi.fn() + resetMsgId = vi.fn() + cacheBuffers = [new ArrayBuffer(1)] + sourceBuffer = { abort: vi.fn() } + + constructor(...args: AudioPlayerCtorArgs) { + mockAudioPlayerConstructor(...args) + mockState.instances.push(this as unknown as MockAudioPlayerInstance) + } + } +}) + +vi.mock('@/app/components/base/audio-btn/audio', () => ({ + default: MockAudioPlayer, +})) + +describe('AudioPlayerManager', () => { + beforeEach(() => { + vi.clearAllMocks() + mockState.instances = [] + Reflect.set(AudioPlayerManager, 'instance', undefined) + }) + + describe('getInstance', () => { + it('should return the same singleton instance across calls', () => { + const first = AudioPlayerManager.getInstance() + const second = AudioPlayerManager.getInstance() + + expect(first).toBe(second) + }) + }) + + describe('getAudioPlayer', () => { + it('should create a new audio player when no existing player is cached', () => { + const manager = AudioPlayerManager.getInstance() + const callback = vi.fn() + + const result = manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback) + + expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(1) + expect(mockAudioPlayerConstructor).toHaveBeenCalledWith( + '/text-to-audio', + false, + 'msg-1', + 'hello', + 'en-US', + callback, + ) + expect(result).toBe(mockState.instances[0]) + }) + + it('should reuse existing player and update callback when msg id is unchanged', () => { + const manager = 
AudioPlayerManager.getInstance() + const firstCallback = vi.fn() + const secondCallback = vi.fn() + + const first = manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', firstCallback) + const second = manager.getAudioPlayer('/ignored', true, 'msg-1', 'ignored', 'fr-FR', secondCallback) + + expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(1) + expect(first).toBe(second) + expect(mockState.instances[0].setCallback).toHaveBeenCalledTimes(1) + expect(mockState.instances[0].setCallback).toHaveBeenCalledWith(secondCallback) + }) + + it('should cleanup existing player and create a new one when msg id changes', () => { + const manager = AudioPlayerManager.getInstance() + const callback = vi.fn() + manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback) + const previous = mockState.instances[0] + + const next = manager.getAudioPlayer('/apps/1/text-to-audio', false, 'msg-2', 'world', 'en-US', callback) + + expect(previous.pauseAudio).toHaveBeenCalledTimes(1) + expect(previous.cacheBuffers).toEqual([]) + expect(previous.sourceBuffer?.abort).toHaveBeenCalledTimes(1) + expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(2) + expect(next).toBe(mockState.instances[1]) + }) + + it('should swallow cleanup errors and still create a new player', () => { + const manager = AudioPlayerManager.getInstance() + const callback = vi.fn() + manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback) + const previous = mockState.instances[0] + previous.pauseAudio.mockImplementation(() => { + throw new Error('cleanup failure') + }) + + expect(() => { + manager.getAudioPlayer('/apps/1/text-to-audio', false, 'msg-2', 'world', 'en-US', callback) + }).not.toThrow() + + expect(previous.pauseAudio).toHaveBeenCalledTimes(1) + expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(2) + }) + }) + + describe('resetMsgId', () => { + it('should forward reset message id to the cached audio player when present', () => 
{ + const manager = AudioPlayerManager.getInstance() + const callback = vi.fn() + manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback) + + manager.resetMsgId('msg-updated') + + expect(mockState.instances[0].resetMsgId).toHaveBeenCalledTimes(1) + expect(mockState.instances[0].resetMsgId).toHaveBeenCalledWith('msg-updated') + }) + + it('should not throw when resetting message id without an audio player', () => { + const manager = AudioPlayerManager.getInstance() + + expect(() => manager.resetMsgId('msg-updated')).not.toThrow() + }) + }) +}) diff --git a/web/app/components/base/audio-btn/__tests__/audio.spec.ts b/web/app/components/base/audio-btn/__tests__/audio.spec.ts new file mode 100644 index 0000000000..00ffea2dfb --- /dev/null +++ b/web/app/components/base/audio-btn/__tests__/audio.spec.ts @@ -0,0 +1,610 @@ +import { Buffer } from 'node:buffer' +import { waitFor } from '@testing-library/react' +import { AppSourceType } from '@/service/share' +import AudioPlayer from '../audio' + +const mockToastNotify = vi.hoisted(() => vi.fn()) +const mockTextToAudioStream = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/base/toast', () => ({ + default: { + notify: (...args: unknown[]) => mockToastNotify(...args), + }, +})) + +vi.mock('@/service/share', () => ({ + AppSourceType: { + webApp: 'webApp', + installedApp: 'installedApp', + }, + textToAudioStream: (...args: unknown[]) => mockTextToAudioStream(...args), +})) + +type AudioEventName = 'ended' | 'paused' | 'loaded' | 'play' | 'timeupdate' | 'loadeddate' | 'canplay' | 'error' | 'sourceopen' + +type AudioEventListener = () => void + +type ReaderResult = { + value: Uint8Array | undefined + done: boolean +} + +type Reader = { + read: () => Promise +} + +type AudioResponse = { + status: number + body: { + getReader: () => Reader + } +} + +class MockSourceBuffer { + updating = false + appendBuffer = vi.fn((_buffer: ArrayBuffer) => undefined) + abort = vi.fn(() => undefined) +} + +class 
MockMediaSource { + readyState: 'open' | 'closed' = 'open' + sourceBuffer = new MockSourceBuffer() + private listeners: Partial> = {} + + addEventListener = vi.fn((event: AudioEventName, listener: AudioEventListener) => { + const listeners = this.listeners[event] || [] + listeners.push(listener) + this.listeners[event] = listeners + }) + + addSourceBuffer = vi.fn((_contentType: string) => this.sourceBuffer) + endOfStream = vi.fn(() => undefined) + + emit(event: AudioEventName) { + const listeners = this.listeners[event] || [] + listeners.forEach((listener) => { + listener() + }) + } +} + +class MockAudio { + src = '' + autoplay = false + disableRemotePlayback = false + controls = false + paused = true + ended = false + played: unknown = null + private listeners: Partial> = {} + + addEventListener = vi.fn((event: AudioEventName, listener: AudioEventListener) => { + const listeners = this.listeners[event] || [] + listeners.push(listener) + this.listeners[event] = listeners + }) + + play = vi.fn(async () => { + this.paused = false + }) + + pause = vi.fn(() => { + this.paused = true + }) + + emit(event: AudioEventName) { + const listeners = this.listeners[event] || [] + listeners.forEach((listener) => { + listener() + }) + } +} + +class MockAudioContext { + state: 'running' | 'suspended' = 'running' + destination = {} + connect = vi.fn(() => undefined) + createMediaElementSource = vi.fn((_audio: MockAudio) => ({ + connect: this.connect, + })) + + resume = vi.fn(async () => { + this.state = 'running' + }) + + suspend = vi.fn(() => { + this.state = 'suspended' + }) +} + +const testState = { + mediaSources: [] as MockMediaSource[], + audios: [] as MockAudio[], + audioContexts: [] as MockAudioContext[], +} + +class MockMediaSourceCtor extends MockMediaSource { + constructor() { + super() + testState.mediaSources.push(this) + } +} + +class MockAudioCtor extends MockAudio { + constructor() { + super() + testState.audios.push(this) + } +} + +class MockAudioContextCtor extends 
MockAudioContext { + constructor() { + super() + testState.audioContexts.push(this) + } +} + +const originalAudio = globalThis.Audio +const originalAudioContext = globalThis.AudioContext +const originalCreateObjectURL = globalThis.URL.createObjectURL +const originalMediaSource = window.MediaSource +const originalManagedMediaSource = window.ManagedMediaSource + +const setMediaSourceSupport = (options: { mediaSource: boolean, managedMediaSource: boolean }) => { + Object.defineProperty(window, 'MediaSource', { + configurable: true, + writable: true, + value: options.mediaSource ? MockMediaSourceCtor : undefined, + }) + Object.defineProperty(window, 'ManagedMediaSource', { + configurable: true, + writable: true, + value: options.managedMediaSource ? MockMediaSourceCtor : undefined, + }) +} + +const makeAudioResponse = (status: number, reads: ReaderResult[]): AudioResponse => { + const read = vi.fn<() => Promise>() + reads.forEach((result) => { + read.mockResolvedValueOnce(result) + }) + + return { + status, + body: { + getReader: () => ({ read }), + }, + } +} + +describe('AudioPlayer', () => { + beforeEach(() => { + vi.clearAllMocks() + testState.mediaSources = [] + testState.audios = [] + testState.audioContexts = [] + + Object.defineProperty(globalThis, 'Audio', { + configurable: true, + writable: true, + value: MockAudioCtor, + }) + Object.defineProperty(globalThis, 'AudioContext', { + configurable: true, + writable: true, + value: MockAudioContextCtor, + }) + Object.defineProperty(globalThis.URL, 'createObjectURL', { + configurable: true, + writable: true, + value: vi.fn(() => 'blob:mock-url'), + }) + + setMediaSourceSupport({ mediaSource: true, managedMediaSource: false }) + }) + + afterAll(() => { + Object.defineProperty(globalThis, 'Audio', { + configurable: true, + writable: true, + value: originalAudio, + }) + Object.defineProperty(globalThis, 'AudioContext', { + configurable: true, + writable: true, + value: originalAudioContext, + }) + 
Object.defineProperty(globalThis.URL, 'createObjectURL', { + configurable: true, + writable: true, + value: originalCreateObjectURL, + }) + Object.defineProperty(window, 'MediaSource', { + configurable: true, + writable: true, + value: originalMediaSource, + }) + Object.defineProperty(window, 'ManagedMediaSource', { + configurable: true, + writable: true, + value: originalManagedMediaSource, + }) + }) + + describe('constructor behavior', () => { + it('should initialize media source, audio, and media element source when MediaSource exists', () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + const mediaSource = testState.mediaSources[0] + + expect(player.mediaSource).toBe(mediaSource as unknown as MediaSource) + expect(globalThis.URL.createObjectURL).toHaveBeenCalledTimes(1) + expect(audio.src).toBe('blob:mock-url') + expect(audio.autoplay).toBe(true) + expect(audioContext.createMediaElementSource).toHaveBeenCalledWith(audio) + expect(audioContext.connect).toHaveBeenCalledTimes(1) + }) + + it('should notify unsupported browser when no MediaSource implementation exists', () => { + setMediaSourceSupport({ mediaSource: false, managedMediaSource: false }) + + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + const audio = testState.audios[0] + + expect(player.mediaSource).toBeNull() + expect(audio.src).toBe('') + expect(mockToastNotify).toHaveBeenCalledTimes(1) + expect(mockToastNotify).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'error', + }), + ) + }) + + it('should configure fallback audio controls when ManagedMediaSource is used', () => { + setMediaSourceSupport({ mediaSource: false, managedMediaSource: true }) + + // Create with callback to ensure constructor path completes with fallback source. 
+ const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, vi.fn()) + const audio = testState.audios[0] + + expect(player.mediaSource).not.toBeNull() + expect(audio.disableRemotePlayback).toBe(true) + expect(audio.controls).toBe(true) + }) + }) + + describe('event wiring', () => { + it('should forward registered audio events to callback', () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + + audio.emit('play') + audio.emit('ended') + audio.emit('error') + audio.emit('paused') + audio.emit('loaded') + audio.emit('timeupdate') + audio.emit('loadeddate') + audio.emit('canplay') + + expect(player.callback).toBe(callback) + expect(callback).toHaveBeenCalledWith('play') + expect(callback).toHaveBeenCalledWith('ended') + expect(callback).toHaveBeenCalledWith('error') + expect(callback).toHaveBeenCalledWith('paused') + expect(callback).toHaveBeenCalledWith('loaded') + expect(callback).toHaveBeenCalledWith('timeupdate') + expect(callback).toHaveBeenCalledWith('loadeddate') + expect(callback).toHaveBeenCalledWith('canplay') + }) + + it('should initialize source buffer only once when sourceopen fires multiple times', () => { + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', vi.fn()) + const mediaSource = testState.mediaSources[0] + + mediaSource.emit('sourceopen') + mediaSource.emit('sourceopen') + + expect(mediaSource.addSourceBuffer).toHaveBeenCalledTimes(1) + expect(player.sourceBuffer).toBe(mediaSource.sourceBuffer) + }) + }) + + describe('playback control', () => { + it('should request streaming audio when playAudio is called before loading', async () => { + mockTextToAudioStream.mockResolvedValue( + makeAudioResponse(200, [ + { value: new Uint8Array([4, 5]), done: false }, + { value: new Uint8Array([1, 2, 3]), done: true }, + ]), + ) + + const player = new AudioPlayer('/text-to-audio', true, 
'msg-1', 'hello', 'en-US', vi.fn()) + player.playAudio() + + await waitFor(() => { + expect(mockTextToAudioStream).toHaveBeenCalledTimes(1) + }) + + expect(mockTextToAudioStream).toHaveBeenCalledWith( + '/text-to-audio', + AppSourceType.webApp, + { content_type: 'audio/mpeg' }, + { + message_id: 'msg-1', + streaming: true, + voice: 'en-US', + text: 'hello', + }, + ) + expect(player.isLoadData).toBe(true) + }) + + it('should emit error callback and reset load flag when stream response status is not 200', async () => { + const callback = vi.fn() + mockTextToAudioStream.mockResolvedValue( + makeAudioResponse(500, [{ value: new Uint8Array([1]), done: true }]), + ) + + const player = new AudioPlayer('/text-to-audio', false, 'msg-2', 'world', undefined, callback) + player.playAudio() + + await waitFor(() => { + expect(callback).toHaveBeenCalledWith('error') + }) + expect(player.isLoadData).toBe(false) + }) + + it('should resume and play immediately when playAudio is called in suspended loaded state', async () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + + player.isLoadData = true + audioContext.state = 'suspended' + player.playAudio() + await Promise.resolve() + + expect(audioContext.resume).toHaveBeenCalledTimes(1) + expect(audio.play).toHaveBeenCalledTimes(1) + expect(callback).toHaveBeenCalledWith('play') + }) + + it('should play ended audio when data is already loaded', () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + + player.isLoadData = true + audioContext.state = 'running' + audio.ended = true + player.playAudio() + + expect(audio.play).toHaveBeenCalledTimes(1) + expect(callback).toHaveBeenCalledWith('play') + }) + + 
it('should only emit play callback without replaying when loaded audio is already playing', () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + + player.isLoadData = true + audioContext.state = 'running' + audio.ended = false + player.playAudio() + + expect(audio.play).not.toHaveBeenCalled() + expect(callback).toHaveBeenCalledWith('play') + }) + + it('should emit error callback when stream request throws', async () => { + const callback = vi.fn() + mockTextToAudioStream.mockRejectedValue(new Error('network failed')) + const player = new AudioPlayer('/text-to-audio', false, 'msg-2', 'world', undefined, callback) + + player.playAudio() + + await waitFor(() => { + expect(callback).toHaveBeenCalledWith('error') + }) + expect(player.isLoadData).toBe(false) + }) + + it('should call pause flow and notify paused event when pauseAudio is invoked', () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + + player.pauseAudio() + + expect(callback).toHaveBeenCalledWith('paused') + expect(audio.pause).toHaveBeenCalledTimes(1) + expect(audioContext.suspend).toHaveBeenCalledTimes(1) + }) + }) + + describe('message and direct-audio helpers', () => { + it('should update message id through resetMsgId', () => { + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + + player.resetMsgId('msg-2') + + expect(player.msgId).toBe('msg-2') + }) + + it('should end stream without playback when playAudioWithAudio receives empty content', async () => { + vi.useFakeTimers() + try { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const mediaSource = 
testState.mediaSources[0] + + await player.playAudioWithAudio('', true) + await vi.advanceTimersByTimeAsync(40) + + expect(player.isLoadData).toBe(false) + expect(player.cacheBuffers).toHaveLength(0) + expect(mediaSource.endOfStream).toHaveBeenCalledTimes(1) + expect(callback).not.toHaveBeenCalledWith('play') + } + finally { + vi.useRealTimers() + } + }) + + it('should decode base64 and start playback when playAudioWithAudio is called with playable content', async () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + const mediaSource = testState.mediaSources[0] + const audioBase64 = Buffer.from('hello').toString('base64') + + mediaSource.emit('sourceopen') + audio.paused = true + await player.playAudioWithAudio(audioBase64, true) + await Promise.resolve() + + expect(player.isLoadData).toBe(true) + expect(player.cacheBuffers).toHaveLength(0) + expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1) + const appendedAudioData = mediaSource.sourceBuffer.appendBuffer.mock.calls[0][0] + expect(appendedAudioData).toBeInstanceOf(ArrayBuffer) + expect(appendedAudioData.byteLength).toBeGreaterThan(0) + expect(audioContext.resume).toHaveBeenCalledTimes(1) + expect(audio.play).toHaveBeenCalledTimes(1) + expect(callback).toHaveBeenCalledWith('play') + }) + + it('should skip playback when playAudioWithAudio is called with play=false', async () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + const audioContext = testState.audioContexts[0] + + await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), false) + + expect(player.isLoadData).toBe(false) + expect(audioContext.resume).not.toHaveBeenCalled() + expect(audio.play).not.toHaveBeenCalled() + 
expect(callback).not.toHaveBeenCalledWith('play') + }) + + it('should play immediately for ended audio in playAudioWithAudio', async () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + + audio.paused = false + audio.ended = true + await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true) + + expect(audio.play).toHaveBeenCalledTimes(1) + expect(callback).toHaveBeenCalledWith('play') + }) + + it('should not replay when played list exists in playAudioWithAudio', async () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + + audio.paused = false + audio.ended = false + audio.played = {} + await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true) + + expect(audio.play).not.toHaveBeenCalled() + expect(callback).not.toHaveBeenCalledWith('play') + }) + + it('should replay when paused is false and played list is empty in playAudioWithAudio', async () => { + const callback = vi.fn() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback) + const audio = testState.audios[0] + + audio.paused = false + audio.ended = false + audio.played = null + await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true) + + expect(audio.play).toHaveBeenCalledTimes(1) + expect(callback).toHaveBeenCalledWith('play') + }) + }) + + describe('buffering internals', () => { + it('should finish stream when receiveAudioData gets an undefined chunk', () => { + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + const finishStream = vi + .spyOn(player as unknown as { finishStream: () => void }, 'finishStream') + .mockImplementation(() => { }) + + ; (player as unknown as { receiveAudioData: (data: Uint8Array | undefined) => void 
}).receiveAudioData(undefined) + + expect(finishStream).toHaveBeenCalledTimes(1) + }) + + it('should finish stream when receiveAudioData gets empty bytes while source is open', () => { + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + const finishStream = vi + .spyOn(player as unknown as { finishStream: () => void }, 'finishStream') + .mockImplementation(() => { }) + + ; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array(0)) + + expect(finishStream).toHaveBeenCalledTimes(1) + }) + + it('should queue incoming buffer when source buffer is updating', () => { + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + const mediaSource = testState.mediaSources[0] + mediaSource.emit('sourceopen') + mediaSource.sourceBuffer.updating = true + + ; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array([1, 2, 3])) + + expect(player.cacheBuffers.length).toBe(1) + }) + + it('should append previously queued buffer before new one when source buffer is idle', () => { + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + const mediaSource = testState.mediaSources[0] + mediaSource.emit('sourceopen') + + const existingBuffer = new ArrayBuffer(2) + player.cacheBuffers = [existingBuffer] + mediaSource.sourceBuffer.updating = false + + ; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array([9])) + + expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1) + expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledWith(existingBuffer) + expect(player.cacheBuffers.length).toBe(1) + }) + + it('should append cache chunks and end stream when finishStream drains buffers', () => { + vi.useFakeTimers() + const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null) + const mediaSource = 
testState.mediaSources[0] + mediaSource.emit('sourceopen') + mediaSource.sourceBuffer.updating = false + player.cacheBuffers = [new ArrayBuffer(3)] + + ; (player as unknown as { finishStream: () => void }).finishStream() + vi.advanceTimersByTime(50) + + expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1) + expect(mediaSource.endOfStream).toHaveBeenCalledTimes(1) + vi.useRealTimers() + }) + }) +}) diff --git a/web/app/components/base/audio-gallery/AudioPlayer.tsx b/web/app/components/base/audio-gallery/AudioPlayer.tsx index 4e5d5e61ab..331dd06c67 100644 --- a/web/app/components/base/audio-gallery/AudioPlayer.tsx +++ b/web/app/components/base/audio-gallery/AudioPlayer.tsx @@ -26,6 +26,7 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { useEffect(() => { const audio = audioRef.current + /* v8 ignore next 2 - @preserve */ if (!audio) return @@ -217,6 +218,7 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { const drawWaveform = useCallback(() => { const canvas = canvasRef.current + /* v8 ignore next 2 - @preserve */ if (!canvas) return @@ -268,14 +270,20 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { drawWaveform() }, [drawWaveform, bufferedTime, hasStartedPlaying]) - const handleMouseMove = useCallback((e: React.MouseEvent) => { + const handleMouseMove = useCallback((e: React.MouseEvent | React.TouchEvent) => { const canvas = canvasRef.current const audio = audioRef.current if (!canvas || !audio) return + const clientX = 'touches' in e + ? e.touches[0]?.clientX ?? 
e.changedTouches[0]?.clientX + : e.clientX + if (clientX === undefined) + return + const rect = canvas.getBoundingClientRect() - const percent = Math.min(Math.max(0, e.clientX - rect.left), rect.width) / rect.width + const percent = Math.min(Math.max(0, clientX - rect.left), rect.width) / rect.width const time = percent * duration // Check if the hovered position is within a buffered range before updating hoverTime @@ -289,20 +297,22 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { return (
-