diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 637593b9de..b92d4c35a8 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -7,7 +7,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml
new file mode 100644
index 0000000000..c57da7cb5f
--- /dev/null
+++ b/.github/actions/setup-web/action.yml
@@ -0,0 +1,33 @@
+name: Setup Web Environment
+description: Set up pnpm and Node.js, and install web dependencies.
+
+inputs:
+ node-version:
+ description: Node.js version to use
+ required: false
+ default: "22"
+ install-dependencies:
+ description: Whether to install web dependencies after setting up Node.js
+ required: false
+ default: "true"
+
+runs:
+ using: composite
+ steps:
+ - name: Install pnpm
+ uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup Node.js
+ uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
+ with:
+ node-version: ${{ inputs.node-version }}
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Install dependencies
+ if: ${{ inputs.install-dependencies == 'true' }}
+ shell: bash
+ run: pnpm --dir web install --frozen-lockfile
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 78f6eefd0d..5715b1e83f 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -19,19 +19,3 @@ updates:
uv-dependencies:
patterns:
- "*"
- - package-ecosystem: "npm"
- directory: "/web"
- schedule:
- interval: "weekly"
- open-pull-requests-limit: 2
- groups:
- storybook:
- patterns:
- - "storybook"
- - "@storybook/*"
- npm-dependencies:
- patterns:
- - "*"
- exclude-patterns:
- - "storybook"
- - "@storybook/*"
diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml
new file mode 100644
index 0000000000..c0d1818691
--- /dev/null
+++ b/.github/workflows/anti-slop.yml
@@ -0,0 +1,19 @@
+name: Anti-Slop PR Check
+
+on:
+ pull_request_target:
+ types: [opened, edited, synchronize]
+
+permissions:
+ pull-requests: write
+ contents: read
+
+jobs:
+ anti-slop:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: peakoss/anti-slop@v0
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ close-pr: false
+ failure-add-pr-labels: "needs-revision"
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 52e3272f99..03f6917dca 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v7
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -51,7 +51,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
- uses: hoverkraft-tech/compose-action@v2
+ uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 4571fd1cd1..4b48e741df 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -12,22 +12,34 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v6
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Check Docker Compose inputs
id: docker-compose-changes
- uses: tj-actions/changed-files@v47
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- - uses: actions/setup-python@v6
+ - name: Check web inputs
+ id: web-changes
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ with:
+ files: |
+ web/**
+ - name: Check api inputs
+ id: api-changes
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ with:
+ files: |
+ api/**
+ - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: "3.11"
- - uses: astral-sh/setup-uv@v7
+ - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -35,7 +47,8 @@ jobs:
cd docker
./generate_docker_compose
- - run: |
+ - if: steps.api-changes.outputs.any_changed == 'true'
+ run: |
cd api
uv sync --dev
# fmt first to avoid line too long
@@ -46,11 +59,13 @@ jobs:
uv run ruff format ..
- name: count migration progress
+ if: steps.api-changes.outputs.any_changed == 'true'
run: |
cd api
./cnt_base.sh
- name: ast-grep
+ if: steps.api-changes.outputs.any_changed == 'true'
run: |
# ast-grep exits 1 if no matches are found; allow idempotent runs.
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
@@ -84,4 +99,16 @@ jobs:
run: |
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
+ - name: Setup web environment
+ if: steps.web-changes.outputs.any_changed == 'true'
+ uses: ./.github/actions/setup-web
+ with:
+ node-version: "24"
+
+ - name: ESLint autofix
+ if: steps.web-changes.outputs.any_changed == 'true'
+ run: |
+ cd web
+ pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
+
+ - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index ac7f3a6b48..94466d151c 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -53,26 +53,26 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
- uses: docker/login-action@v3
+ uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
- uses: docker/setup-qemu-action@v3
+ uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v3
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Extract metadata for Docker
id: meta
- uses: docker/metadata-action@v5
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
with:
images: ${{ env[matrix.image_name_env] }}
- name: Build Docker image
id: build
- uses: docker/build-push-action@v6
+ uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ matrix.platform }}
@@ -91,7 +91,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
- uses: actions/upload-artifact@v6
+ uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
@@ -113,21 +113,21 @@ jobs:
context: "web"
steps:
- name: Download digests
- uses: actions/download-artifact@v7
+ uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
merge-multiple: true
- name: Login to Docker Hub
- uses: docker/login-action@v3
+ uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Extract metadata for Docker
id: meta
- uses: docker/metadata-action@v5
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml
index e20cf9850b..84a506a325 100644
--- a/.github/workflows/db-migration-test.yml
+++ b/.github/workflows/db-migration-test.yml
@@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v7
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: "3.12"
@@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
- uses: hoverkraft-tech/compose-action@v2.0.2
+ uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v7
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: "3.12"
@@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
- uses: hoverkraft-tech/compose-action@v2.0.2
+ uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-agent-dev.yml
index dd759f7ba5..cd5fe9242e 100644
--- a/.github/workflows/deploy-agent-dev.yml
+++ b/.github/workflows/deploy-agent-dev.yml
@@ -19,7 +19,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v1
+ uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml
index 38fa0b9a7f..954537663a 100644
--- a/.github/workflows/deploy-dev.yml
+++ b/.github/workflows/deploy-dev.yml
@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v1
+ uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml
index a3fd52afc6..c6f1cc7e6f 100644
--- a/.github/workflows/deploy-hitl.yml
+++ b/.github/workflows/deploy-hitl.yml
@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'build/feat/hitl'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v1
+ uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml
index cadc1b5507..340b380dc9 100644
--- a/.github/workflows/docker-build.yml
+++ b/.github/workflows/docker-build.yml
@@ -32,13 +32,13 @@ jobs:
context: "web"
steps:
- name: Set up QEMU
- uses: docker/setup-qemu-action@v3
+ uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v3
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
- uses: docker/build-push-action@v6
+ uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: "{{defaultContext}}:${{ matrix.context }}"
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index 06782b53c1..278e10bc04 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -9,6 +9,6 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- - uses: actions/labeler@v6
+ - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:
sync-labels: true
diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml
index d6653de950..ef2e3c7bb4 100644
--- a/.github/workflows/main-ci.yml
+++ b/.github/workflows/main-ci.yml
@@ -27,8 +27,8 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- - uses: actions/checkout@v6
- - uses: dorny/paths-filter@v3
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
id: changes
with:
filters: |
@@ -39,6 +39,7 @@ jobs:
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
+ - '.github/actions/setup-web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml
index f9fbcba465..0278e1e0d3 100644
--- a/.github/workflows/pyrefly-diff-comment.yml
+++ b/.github/workflows/pyrefly-diff-comment.yml
@@ -21,7 +21,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
- uses: actions/github-script@v8
+ uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -49,7 +49,7 @@ jobs:
run: unzip -o pyrefly_diff.zip
- name: Post comment
- uses: actions/github-script@v8
+ uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml
index 14338e85b3..cceaf58789 100644
--- a/.github/workflows/pyrefly-diff.yml
+++ b/.github/workflows/pyrefly-diff.yml
@@ -17,12 +17,12 @@ jobs:
pull-requests: write
steps:
- name: Checkout PR branch
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Setup Python & UV
- uses: astral-sh/setup-uv@v5
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
@@ -55,7 +55,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: pyrefly_diff
path: |
@@ -64,7 +64,7 @@ jobs:
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
- uses: actions/github-script@v8
+ uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml
index b15c26a096..c21331ec0d 100644
--- a/.github/workflows/semantic-pull-request.yml
+++ b/.github/workflows/semantic-pull-request.yml
@@ -16,6 +16,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check title
- uses: amannn/action-semantic-pull-request@v6.1.1
+ uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index b6df1d7e93..5cf52daed2 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- - uses: actions/stale@v10
+ - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
with:
days-before-issue-stale: 15
days-before-issue-close: 3
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index eb13c3d096..4168f890f5 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v47
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
- uses: astral-sh/setup-uv@v7
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: false
python-version: "3.12"
@@ -67,36 +67,22 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v47
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/style.yml
+ .github/actions/setup-web/**
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Setup NodeJS
- uses: actions/setup-node@v6
+ - name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
-
- - name: Web dependencies
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm install --frozen-lockfile
+ uses: ./.github/actions/setup-web
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
@@ -134,14 +120,14 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v47
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
**.sh
@@ -152,7 +138,7 @@ jobs:
.editorconfig
- name: Super-linter
- uses: super-linter/super-linter/slim@v8
+ uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index d9a1168636..3fc351c0c2 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -21,12 +21,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- - uses: actions/checkout@v6
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Use Node.js
- uses: actions/setup-node@v6
+ uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: 22
cache: ''
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index b431c36a8b..ff07313ebe 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -38,7 +38,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@@ -48,18 +48,10 @@ jobs:
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- - name: Install pnpm
- uses: pnpm/action-setup@v4
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Set up Node.js
- uses: actions/setup-node@v6
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
+ install-dependencies: "false"
- name: Detect changed files and generate diff
id: detect_changes
@@ -130,7 +122,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
- uses: anthropics/claude-code-action@v1
+ uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml
index 66a29453b4..1caaddd47a 100644
--- a/.github/workflows/trigger-i18n-sync.yml
+++ b/.github/workflows/trigger-i18n-sync.yml
@@ -21,7 +21,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
@@ -59,7 +59,7 @@ jobs:
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
- uses: peter-evans/repository-dispatch@v3
+ uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index 7735afdaca..8cb7db7601 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Free Disk Space
- uses: endersonmenezes/free-disk-space@v3
+ uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v7
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -60,7 +60,7 @@ jobs:
# tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
- uses: hoverkraft-tech/compose-action@v2.0.2
+ uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 659620b2a9..33e9170b02 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -26,32 +26,19 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Setup Node.js
- uses: actions/setup-node@v6
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
-
- - name: Install dependencies
- run: pnpm install --frozen-lockfile
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
- uses: actions/upload-artifact@v6
+ uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*
@@ -70,28 +57,15 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Setup Node.js
- uses: actions/setup-node@v6
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
-
- - name: Install dependencies
- run: pnpm install --frozen-lockfile
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
- name: Download blob reports
- uses: actions/download-artifact@v6
+ uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: web/.vitest-reports
pattern: blob-report-*
@@ -419,7 +393,7 @@ jobs:
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
- uses: actions/upload-artifact@v6
+ uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: web-coverage-report
path: web/coverage
@@ -435,36 +409,22 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v6
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v47
+ uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/web-tests.yml
+ .github/actions/setup-web/**
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Setup NodeJS
- uses: actions/setup-node@v6
+ - name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
-
- - name: Web dependencies
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm install --frozen-lockfile
+ uses: ./.github/actions/setup-web
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template
index 700b815c3b..c3e2c50c52 100644
--- a/.vscode/launch.json.template
+++ b/.vscode/launch.json.template
@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
- "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
+ "dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"--loglevel",
"INFO"
],
diff --git a/Makefile b/Makefile
index 0aff26b3e5..55871c86a7 100644
--- a/Makefile
+++ b/Makefile
@@ -68,8 +68,9 @@ lint:
@echo "✅ Linting complete"
type-check:
- @echo "📝 Running type checks (basedpyright + mypy)..."
+ @echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
+ @./dev/pyrefly-check-local
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
@@ -131,7 +132,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
- @echo " make type-check - Run type checks (basedpyright, mypy)"
+ @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/
diff --git a/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py
new file mode 100644
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py
+from core.rag.extractor.html_extractor import HtmlExtractor
+
+
+class TestHtmlExtractor:
+    def test_extract_returns_document(self, tmp_path):
+        file_path = tmp_path / "sample.html"
+        file_path.write_text("<title>Title</title><p>Hello</p>", encoding="utf-8")
+
+        extractor = HtmlExtractor(str(file_path))
+        docs = extractor.extract()
+
+        assert len(docs) == 1
+        assert "".join(docs[0].page_content.split()) == "TitleHello"
+
+    def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path):
+        file_path = tmp_path / "sample.html"
+        file_path.write_text(" \n ", encoding="utf-8")
+
+        extractor = HtmlExtractor(str(file_path))
+
+        assert extractor._load_as_text() == ""
diff --git a/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py
new file mode 100644
index 0000000000..0b4c9bd809
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py
@@ -0,0 +1,47 @@
+from pytest_mock import MockerFixture
+
+from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
+
+
+class TestJinaReaderWebExtractor:
+    def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture):
+        mocker.patch(
+            "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
+            return_value={
+                "content": "markdown-content",
+                "url": "https://example.com",
+                "description": "desc",
+                "title": "title",
+            },
+        )
+
+        extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
+        docs = extractor.extract()
+
+        assert len(docs) == 1
+        assert docs[0].page_content == "markdown-content"
+        assert docs[0].metadata == {
+            "source_url": "https://example.com",
+            "description": "desc",
+            "title": "title",
+        }
+
+    def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture):
+        mocker.patch(
+            "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
+            return_value=None,
+        )
+
+        extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
+
+        assert extractor.extract() == []
+
+    def test_extract_non_crawl_mode_returns_empty(self, mocker: MockerFixture):
+        mock_get_crawl = mocker.patch(
+            "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
+            return_value={"content": "unused"},
+        )
+        extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape")
+
+        assert extractor.extract() == []
+        mock_get_crawl.assert_not_called()
diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py
index d4cf534c56..7e78c86c7d 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py
@@ -1,8 +1,15 @@
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+import core.rag.extractor.markdown_extractor as markdown_module
 from core.rag.extractor.markdown_extractor import MarkdownExtractor


-def test_markdown_to_tups():
-    markdown = """
+class TestMarkdownExtractor:
+    def test_markdown_to_tups(self):
+        markdown = """
 this is some text without header

 # title 1
 this is balabala text

 ## title 2
 this is more specific text.
""" - extractor = MarkdownExtractor(file_path="dummy_path") - updated_output = extractor.markdown_to_tups(markdown) - assert len(updated_output) == 3 - key, header_value = updated_output[0] - assert key == None - assert header_value.strip() == "this is some text without header" - title_1, value = updated_output[1] - assert title_1.strip() == "title 1" - assert value.strip() == "this is balabala text" + extractor = MarkdownExtractor(file_path="dummy_path") + updated_output = extractor.markdown_to_tups(markdown) + + assert len(updated_output) == 3 + key, header_value = updated_output[0] + assert key is None + assert header_value.strip() == "this is some text without header" + + title_1, value = updated_output[1] + assert title_1.strip() == "title 1" + assert value.strip() == "this is balabala text" + + def test_markdown_to_tups_keeps_code_block_headers_literal(self): + markdown = """# Header +before +```python +# this is not a heading +print('x') +``` +after +""" + extractor = MarkdownExtractor(file_path="dummy_path") + + tups = extractor.markdown_to_tups(markdown) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "# this is not a heading" in tups[1][1] + + def test_remove_images_and_hyperlinks(self): + extractor = MarkdownExtractor(file_path="dummy_path") + + with_images = "before ![[image.png]] after" + with_links = "[OpenAI](https://openai.com)" + + assert extractor.remove_images(with_images) == "before after" + assert extractor.remove_hyperlinks(with_links) == "OpenAI" + + def test_parse_tups_reads_file_and_applies_options(self, tmp_path): + markdown_file = tmp_path / "doc.md" + markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8") + + extractor = MarkdownExtractor( + file_path=str(markdown_file), + remove_hyperlinks=True, + remove_images=True, + autodetect_encoding=False, + ) + + tups = extractor.parse_tups(str(markdown_file)) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "[link]" not in tups[1][1] + assert "img.png" not in tups[1][1] + + def test_parse_tups_autodetects_encoding_after_decode_error(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=True) + + calls: list[str | None] = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + if encoding == "bad-encoding": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + return "# H\ncontent" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + markdown_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad-encoding"), SimpleNamespace(encoding="utf-8")], + ) + + tups = extractor.parse_tups("dummy_path") + + assert len(tups) == 2 + assert calls == [None, "bad-encoding", "utf-8"] + + def test_parse_tups_decode_error_with_autodetect_disabled_raises(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=False) + + def raise_decode(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + + monkeypatch.setattr(Path, "read_text", raise_decode, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_parse_tups_other_exceptions_are_wrapped(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + + def raise_other(self, encoding=None): + raise OSError("disk error") + + 
+        monkeypatch.setattr(Path, "read_text", raise_other, raising=True)
+
+        with pytest.raises(RuntimeError, match="Error loading dummy_path"):
+            extractor.parse_tups("dummy_path")
+
+    def test_extract_builds_documents_for_header_and_non_header(self, monkeypatch):
+        extractor = MarkdownExtractor(file_path="dummy_path")
+        monkeypatch.setattr(extractor, "parse_tups", lambda _: [(None, "plain"), ("Header", "value")])
+
+        docs = extractor.extract()
+
+        assert [doc.page_content for doc in docs] == ["plain", "\n\nHeader\nvalue"]
diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py
index 58bec7d19e..6daee11f8f 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py
@@ -1,93 +1,499 @@
+from types import SimpleNamespace
 from unittest import mock

+import httpx
+import pytest
 from pytest_mock import MockerFixture

 from core.rag.extractor import notion_extractor

-user_id = "user1"
-database_id = "database1"
-page_id = "page1"
-
-extractor = notion_extractor.NotionExtractor(
-    notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
-)
-
-
-def _generate_page(page_title: str):
-    return {
-        "object": "page",
-        "id": page_id,
-        "properties": {
-            "Page": {
-                "type": "title",
-                "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}],
-            }
-        },
-    }
-
-
-def _generate_block(block_id: str, block_type: str, block_text: str):
-    return {
-        "object": "block",
-        "id": block_id,
-        "parent": {"type": "page_id", "page_id": page_id},
-        "type": block_type,
-        "has_children": False,
-        block_type: {
-            "rich_text": [
-                {
-                    "type": "text",
-                    "text": {"content": block_text},
-                    "plain_text": block_text,
-                }
-            ]
-        },
-    }
-
-
-def _mock_response(data):
+def _mock_response(data, status_code: int = 200, text: str = ""):
     response = mock.Mock()
-    response.status_code = 200
+    response.status_code = status_code
+    response.text = text
     response.json.return_value = data
     return response

-def _remove_multiple_new_lines(text):
-    while "\n\n" in text:
-        text = text.replace("\n\n", "\n")
-    return text.strip()

+class TestNotionExtractorInitAndPublicMethods:
+    def test_init_with_explicit_token(self):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="page",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+
+        assert extractor._notion_access_token == "token"
+
+    def test_init_falls_back_to_env_token_when_credential_lookup_fails(self, monkeypatch):
+        monkeypatch.setattr(
+            notion_extractor.NotionExtractor,
+            "_get_access_token",
+            classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))),
+        )
+        monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", "env-token", raising=False)
+
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="page",
+            tenant_id="tenant",
+            credential_id="cred",
+        )
+
+        assert extractor._notion_access_token == "env-token"
+
+    def test_init_raises_if_no_credential_and_no_env_token(self, monkeypatch):
+        monkeypatch.setattr(
+            notion_extractor.NotionExtractor,
+            "_get_access_token",
+            classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))),
+        )
+        monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", None, raising=False)
"NOTION_INTEGRATION_TOKEN", None, raising=False) + + with pytest.raises(ValueError, match="Must specify `integration_token`"): + notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + def test_extract_updates_last_edited_and_loads_documents(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + update_mock = mock.Mock() + load_mock = mock.Mock(return_value=[SimpleNamespace(page_content="doc")]) + monkeypatch.setattr(extractor, "update_last_edited_time", update_mock) + monkeypatch.setattr(extractor, "_load_data_as_documents", load_mock) + + docs = extractor.extract() + + update_mock.assert_called_once_with(None) + load_mock.assert_called_once_with("obj", "page") + assert len(docs) == 1 + + def test_load_data_as_documents_page_database_and_invalid(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + monkeypatch.setattr(extractor, "_get_notion_block_data", lambda _: ["line1", "line2"]) + page_docs = extractor._load_data_as_documents("page-id", "page") + assert page_docs[0].page_content == "line1\nline2" + + monkeypatch.setattr(extractor, "_get_notion_database_data", lambda _: [SimpleNamespace(page_content="db")]) + db_docs = extractor._load_data_as_documents("db-id", "database") + assert db_docs[0].page_content == "db" + + with pytest.raises(ValueError, match="notion page type not supported"): + extractor._load_data_as_documents("obj", "unsupported") -def test_notion_page(mocker: MockerFixture): - texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] - mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]), - ], - "next_cursor": None, - } - mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page)) +class TestNotionDatabase: + def test_get_notion_database_data_parses_property_types_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) - page_docs = extractor._load_data_as_documents(page_id, "page") - assert len(page_docs) == 1 - content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" + first_page = { + "results": [ + { + "properties": { + "tags": { + "type": "multi_select", + "multi_select": [{"name": "A"}, {"name": "B"}], + }, + "title_prop": {"type": "title", "title": [{"plain_text": "Title"}]}, + "empty_title": {"type": "title", "title": []}, + "rich": {"type": "rich_text", "rich_text": [{"plain_text": "RichText"}]}, + "empty_rich": {"type": "rich_text", "rich_text": []}, + "select_prop": {"type": "select", "select": {"name": "Selected"}}, + "empty_select": {"type": "select", "select": None}, + "status_prop": {"type": "status", "status": {"name": "Open"}}, + "empty_status": {"type": "status", "status": None}, + "number_prop": {"type": "number", "number": 10}, + "dict_prop": {"type": "date", "date": {"start": "2024-01-01", "end": 
+                    },
+                    "url": "https://notion.so/page-1",
+                }
+            ],
+            "has_more": True,
+            "next_cursor": "cursor-2",
+        }
+        second_page = {"results": [], "has_more": False, "next_cursor": None}
+
+        mock_post = mocker.patch("httpx.post", side_effect=[_mock_response(first_page), _mock_response(second_page)])
+
+        docs = extractor._get_notion_database_data("db-1", query_dict={"filter": {"x": 1}})
+
+        assert len(docs) == 1
+        content = docs[0].page_content
+        assert "tags:['A', 'B']" in content
+        assert "title_prop:Title" in content
+        assert "rich:RichText" in content
+        assert "number_prop:10" in content
+        assert "dict_prop:start:2024-01-01" in content
+        assert "Row Page URL:https://notion.so/page-1" in content
+        assert mock_post.call_count == 2
+
+    def test_get_notion_database_data_handles_missing_results_and_empty_content(self, mocker: MockerFixture):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="database",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+
+        mocker.patch("httpx.post", return_value=_mock_response({"results": None}))
+        assert extractor._get_notion_database_data("db-1") == []
+
+    def test_get_notion_database_data_requires_access_token(self):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="database",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+        extractor._notion_access_token = None
+
+        with pytest.raises(AssertionError, match="Notion access token is required"):
+            extractor._get_notion_database_data("db-1")

-def test_notion_database(mocker: MockerFixture):
-    page_title_list = ["page1", "page2", "page3"]
-    mocked_notion_database = {
-        "object": "list",
-        "results": [_generate_page(i) for i in page_title_list],
-        "next_cursor": None,
-    }
-    mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
-    database_docs = extractor._load_data_as_documents(database_id, "database")
-    assert len(database_docs) == 1
-    content = _remove_multiple_new_lines(database_docs[0].page_content)
-    assert content == "\n".join([f"Page:{i}" for i in page_title_list])

+class TestNotionBlocks:
+    def test_get_notion_block_data_success_with_table_headings_children_and_pagination(self, mocker: MockerFixture):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="page",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+
+        first_response = {
+            "results": [
+                {"type": "table", "id": "tbl-1", "has_children": False, "table": {}},
+                {
+                    "type": "heading_1",
+                    "id": "h1",
+                    "has_children": False,
+                    "heading_1": {"rich_text": [{"text": {"content": "Heading"}}]},
+                },
+                {
+                    "type": "paragraph",
+                    "id": "p1",
+                    "has_children": True,
+                    "paragraph": {"rich_text": [{"text": {"content": "Paragraph"}}]},
+                },
+                {
+                    "type": "child_page",
+                    "id": "cp1",
+                    "has_children": True,
+                    "child_page": {"rich_text": []},
+                },
+            ],
+            "next_cursor": "cursor-2",
+        }
+        second_response = {
+            "results": [
+                {
+                    "type": "heading_2",
+                    "id": "h2",
+                    "has_children": False,
+                    "heading_2": {"rich_text": [{"text": {"content": "SubHeading"}}]},
+                }
+            ],
+            "next_cursor": None,
+        }
+
+        mocker.patch("httpx.request", side_effect=[_mock_response(first_response), _mock_response(second_response)])
+        mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE")
+        mocker.patch.object(extractor, "_read_block", return_value="CHILD")
+
+        lines = extractor._get_notion_block_data("page-1")
+
+        assert lines[0] == "TABLE\n\n"
== "TABLE\n\n" + assert "# Heading" in lines[1] + assert "Paragraph\nCHILD\n\n" in lines[2] + assert "## SubHeading" in lines[-1] + + def test_get_notion_block_data_handles_http_error_and_invalid_payload(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.request", side_effect=httpx.HTTPError("network")) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"bad": "payload"}, status_code=200)) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"results": []}, status_code=500, text="boom")) + with pytest.raises(ValueError, match="Error fetching Notion block data: boom"): + extractor._get_notion_block_data("page-1") + + def test_read_block_supports_heading_table_and_recursion(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + root_payload = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "Root"}}]}, + }, + { + "type": "paragraph", + "id": "child-block", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Parent"}}]}, + }, + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + ], + "next_cursor": None, + } + child_payload = { + "results": [ + { + "type": "paragraph", + "id": "leaf", + "has_children": False, + "paragraph": {"rich_text": [{"text": {"content": "Child"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(root_payload), _mock_response(child_payload)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE-MD") + + content = extractor._read_block("root") + + assert "## Root" in content + assert "Parent" in content + assert "Child" in content + assert "TABLE-MD" in content + + def test_read_block_breaks_on_missing_results(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + mocker.patch("httpx.request", return_value=_mock_response({"results": None, "next_cursor": None})) + + assert extractor._read_block("root") == "" + + def test_read_table_rows_formats_markdown_with_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + page_one = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [{"text": {"content": "H2"}}], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R1C1"}}], + [{"text": {"content": "R1C2"}}], + ] + } + }, + ], + "next_cursor": "next", + } + page_two = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R2C1"}}], + [{"text": {"content": "R2C2"}}], + ] + } + }, + ], + "next_cursor": 
+        }
+
+        mocker.patch("httpx.request", side_effect=[_mock_response(page_one), _mock_response(page_two)])
+
+        markdown = extractor._read_table_rows("tbl-1")
+
+        assert "| H1 | H2 |" in markdown
+        assert "| R1C1 | R1C2 |" in markdown
+        assert "| H1 | |" in markdown
+        assert "| R2C1 | R2C2 |" in markdown
+
+
+class TestNotionMetadataAndCredentialMethods:
+    def test_update_last_edited_time_no_document_model(self):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="page",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+
+        assert extractor.update_last_edited_time(None) is None
+
+    def test_update_last_edited_time_updates_document_and_commits(self, monkeypatch):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="page",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+
+        class FakeDocumentModel:
+            data_source_info = "data_source_info"
+
+        update_calls = []
+
+        class FakeQuery:
+            def filter_by(self, **kwargs):
+                return self
+
+            def update(self, payload):
+                update_calls.append(payload)
+
+        class FakeSession:
+            committed = False
+
+            def query(self, model):
+                assert model is FakeDocumentModel
+                return FakeQuery()
+
+            def commit(self):
+                self.committed = True
+
+        fake_db = SimpleNamespace(session=FakeSession())
+        monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
+        monkeypatch.setattr(notion_extractor, "db", fake_db)
+        monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
+
+        doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
+        extractor.update_last_edited_time(doc_model)
+
+        assert update_calls
+        assert fake_db.session.committed is True
+
+    def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):
+        extractor_page = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="page-id",
+            notion_page_type="page",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+        request_mock = mocker.patch(
+            "httpx.request", return_value=_mock_response({"last_edited_time": "2025-05-01T00:00:00.000Z"})
+        )
+
+        assert extractor_page.get_notion_last_edited_time() == "2025-05-01T00:00:00.000Z"
+        assert "pages/page-id" in request_mock.call_args[0][1]
+
+        extractor_db = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="db-id",
+            notion_page_type="database",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+        request_mock = mocker.patch(
+            "httpx.request", return_value=_mock_response({"last_edited_time": "2025-06-01T00:00:00.000Z"})
+        )
+
+        assert extractor_db.get_notion_last_edited_time() == "2025-06-01T00:00:00.000Z"
+        assert "databases/db-id" in request_mock.call_args[0][1]
+
+    def test_get_notion_last_edited_time_requires_access_token(self):
+        extractor = notion_extractor.NotionExtractor(
+            notion_workspace_id="ws",
+            notion_obj_id="obj",
+            notion_page_type="page",
+            tenant_id="tenant",
+            notion_access_token="token",
+        )
+        extractor._notion_access_token = None
+
+        with pytest.raises(AssertionError, match="Notion access token is required"):
+            extractor.get_notion_last_edited_time()
+
+    def test_get_access_token_success_and_errors(self, monkeypatch):
+        with pytest.raises(Exception, match="No credential id found"):
+            notion_extractor.NotionExtractor._get_access_token("tenant", None)
+
+        class FakeProviderServiceMissing:
+            def get_datasource_credentials(self, **kwargs):
+                return None
+
+        monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceMissing)
+        with pytest.raises(Exception, match="No notion credential found"):
+            notion_extractor.NotionExtractor._get_access_token("tenant", "cred")
+
+        class FakeProviderServiceFound:
+            def get_datasource_credentials(self, **kwargs):
+                return {"integration_secret": "token-from-credential"}
+
+        monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceFound)
+
+        assert notion_extractor.NotionExtractor._get_access_token("tenant", "cred") == "token-from-credential"
diff --git a/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py
new file mode 100644
index 0000000000..fb3c6e52c6
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py
@@ -0,0 +1,79 @@
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+import core.rag.extractor.text_extractor as text_module
+from core.rag.extractor.text_extractor import TextExtractor
+
+
+class TestTextExtractor:
+    def test_extract_success(self, tmp_path):
+        file_path = tmp_path / "data.txt"
+        file_path.write_text("hello world", encoding="utf-8")
+
+        extractor = TextExtractor(str(file_path))
+        docs = extractor.extract()
+
+        assert len(docs) == 1
+        assert docs[0].page_content == "hello world"
+        assert docs[0].metadata == {"source": str(file_path)}
+
+    def test_extract_autodetect_success_after_decode_error(self, monkeypatch):
+        extractor = TextExtractor("dummy.txt", autodetect_encoding=True)
+
+        calls = []
+
+        def fake_read_text(self, encoding=None):
+            calls.append(encoding)
+            if encoding is None:
+                raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
+            if encoding == "bad":
+                raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
+            return "decoded text"
+
+        monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True)
+        monkeypatch.setattr(
+            text_module,
+            "detect_file_encodings",
+            lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")],
+        )
+
+        docs = extractor.extract()
+
+        assert docs[0].page_content == "decoded text"
+        assert calls == [None, "bad", "utf-8"]
+
+    def test_extract_autodetect_all_fail_raises_runtime_error(self, monkeypatch):
+        extractor = TextExtractor("dummy.txt", autodetect_encoding=True)
+
+        def always_decode_error(self, encoding=None):
+            raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
+
+        monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True)
+        monkeypatch.setattr(text_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="latin-1")])
+
+        with pytest.raises(RuntimeError, match="all detected encodings failed"):
+            extractor.extract()
+
+    def test_extract_decode_error_without_autodetect_raises_runtime_error(self, monkeypatch):
+        extractor = TextExtractor("dummy.txt", autodetect_encoding=False)
+
+        def always_decode_error(self, encoding=None):
+            raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
+
+        monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True)
+
+        with pytest.raises(RuntimeError, match="specified encoding failed"):
+            extractor.extract()
+
+    def test_extract_wraps_non_decode_exceptions(self, monkeypatch):
+        extractor = TextExtractor("dummy.txt")
+
+        def raise_other(self, encoding=None):
+            raise OSError("io error")
+
+        monkeypatch.setattr(Path, "read_text", raise_other, raising=True)
+
+        with pytest.raises(RuntimeError, match="Error loading dummy.txt"):
+            extractor.extract()
diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
index 0792ada194..64eb89590a 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
@@ -3,9 +3,12 @@ import io
 import os
 import tempfile
+from collections import UserDict
 from pathlib import Path
 from types import SimpleNamespace
+from unittest.mock import MagicMock

+import pytest
 from docx import Document
 from docx.oxml import OxmlElement
 from docx.oxml.ns import qn
@@ -136,7 +139,7 @@ def test_extract_images_from_docx(monkeypatch):
     monkeypatch.setattr(we, "UploadFile", FakeUploadFile)

     # Patch external image fetcher
-    def fake_get(url: str):
+    def fake_get(url: str, **kwargs):
         assert url == "https://example.com/image.png"
         return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)
@@ -203,10 +206,8 @@ def test_extract_images_from_docx_uses_internal_files_url():
     finally:
         # Restore original values
-        if original_files_url is not None:
-            dify_config.FILES_URL = original_files_url
-        if original_internal_files_url is not None:
-            dify_config.INTERNAL_FILES_URL = original_internal_files_url
+        dify_config.FILES_URL = original_files_url
+        dify_config.INTERNAL_FILES_URL = original_internal_files_url


 def test_extract_hyperlinks(monkeypatch):
@@ -314,3 +315,405 @@ def test_extract_legacy_hyperlinks(monkeypatch):
     finally:
         if os.path.exists(tmp_path):
             os.remove(tmp_path)
+
+
+def test_init_rejects_invalid_url_status(monkeypatch):
+    class FakeResponse:
+        status_code = 404
+        content = b""
+        closed = False
+
+        def close(self):
+            self.closed = True
+
+    fake_response = FakeResponse()
+    monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: fake_response))
+
+    with pytest.raises(ValueError, match="returned status code 404"):
+        WordExtractor("https://example.com/missing.docx", "tenant", "user")
+
+    assert fake_response.closed is True
+
+
+def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
+    target_file = tmp_path / "expanded.docx"
+    target_file.write_bytes(b"docx")
+
+    monkeypatch.setattr(we.os.path, "expanduser", lambda p: str(target_file))
+    monkeypatch.setattr(
+        we.os.path,
+        "isfile",
+        lambda p: p == str(target_file),
+    )
+
+    extractor = WordExtractor("~/expanded.docx", "tenant", "user")
+    assert extractor.file_path == str(target_file)
+
+    monkeypatch.setattr(we.os.path, "isfile", lambda p: False)
+    with pytest.raises(ValueError, match="is not a valid file or url"):
+        WordExtractor("not-a-file", "tenant", "user")
+
+
+def test_del_closes_temp_file():
+    extractor = object.__new__(WordExtractor)
+    extractor.temp_file = MagicMock()
+
+    WordExtractor.__del__(extractor)
+
+    extractor.temp_file.close.assert_called_once()
+
+
+def test_extract_images_handles_invalid_external_cases(monkeypatch):
+    class FakeTargetRef:
+        def __contains__(self, item):
+            return item == "image"
+
+        def split(self, sep):
+            return [None]
+
+    rel_invalid_url = SimpleNamespace(is_external=True, target_ref="image-no-url")
+    rel_request_error = SimpleNamespace(is_external=True, target_ref="https://example.com/image-error")
+    rel_unknown_mime = SimpleNamespace(is_external=True, target_ref="https://example.com/image-unknown")
+    rel_internal_none_ext = SimpleNamespace(is_external=False, target_ref=FakeTargetRef(), target_part=object())
+
+    doc = SimpleNamespace(
+        part=SimpleNamespace(
+            rels={
+                "r1": rel_invalid_url,
"r2": rel_request_error, + "r3": rel_unknown_mime, + "r4": rel_internal_none_ext, + } + ) + ) + + def fake_get(url, **kwargs): + if "image-error" in url: + raise RuntimeError("network") + return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x") + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock())) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None)) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "tenant" + extractor.user_id = "user" + + result = extractor._extract_images_from_docx(doc) + + assert result == {} + db_stub.session.commit.assert_called_once() + + +def test_table_to_markdown_and_parse_helpers(monkeypatch): + extractor = object.__new__(WordExtractor) + + table = SimpleNamespace( + rows=[ + SimpleNamespace(cells=[1, 2]), + SimpleNamespace(cells=[3, 4]), + ] + ) + parse_row_mock = MagicMock(side_effect=[["H1", "H2"], ["A", "B"]]) + monkeypatch.setattr(extractor, "_parse_row", parse_row_mock) + + markdown = extractor._table_to_markdown(table, {}) + assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |" + + class FakeBlip: + def __init__(self, image_id): + self.image_id = image_id + + def get(self, key): + return self.image_id + + class FakeRunChild: + def __init__(self, blips, text=""): + self._blips = blips + self.text = text + self.tag = qn("w:r") + + def xpath(self, pattern): + if pattern == ".//a:blip": + return self._blips + return [] + + class FakeRun: + def __init__(self, element, paragraph): + # Mirror the subset used by _parse_cell_paragraph + self.element = element + self.text = getattr(element, "text", "") + + # Patch we.Run so our lightweight child objects work with the extractor + monkeypatch.setattr(we, "Run", FakeRun) + + image_part = object() + paragraph = SimpleNamespace( + _element=[ + FakeRunChild([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")], text=""), + FakeRunChild([], text="plain"), + ], + part=SimpleNamespace( + rels={ + "ext": SimpleNamespace(is_external=True), + "int": SimpleNamespace(is_external=False, target_part=image_part), + } + ), + ) + + image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"} + assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain" + + cell = SimpleNamespace(paragraphs=[paragraph, paragraph]) + assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain" + + +def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch): + extractor = object.__new__(WordExtractor) + + ext_image_id = "ext-image" + int_embed_id = "int-embed" + shape_ext_id = "shape-ext" + shape_int_id = "shape-int" + + internal_part = object() + shape_internal_part = object() + + class Rels(UserDict): + def get(self, key, default=None): + if key == "link-bad": + raise RuntimeError("cannot resolve relation") + return super().get(key, default) + + rels = Rels( + { + ext_image_id: SimpleNamespace(is_external=True, target_ref="https://img/ext.png"), + int_embed_id: SimpleNamespace(is_external=False, target_part=internal_part), + shape_ext_id: SimpleNamespace(is_external=True, target_ref="https://img/shape.png"), + shape_int_id: SimpleNamespace(is_external=False, target_part=shape_internal_part), + "link-ok": SimpleNamespace(is_external=True, 
target_ref="https://example.com"), + } + ) + + image_map = { + ext_image_id: "[EXT]", + internal_part: "[INT]", + shape_ext_id: "[SHAPE_EXT]", + shape_internal_part: "[SHAPE_INT]", + } + + class FakeBlip: + def __init__(self, embed_id): + self.embed_id = embed_id + + def get(self, key): + return self.embed_id + + class FakeDrawing: + def __init__(self, embed_ids): + self.embed_ids = embed_ids + + def findall(self, pattern): + return [FakeBlip(embed_id) for embed_id in self.embed_ids] + + class FakeNode: + def __init__(self, text=None, attrs=None): + self.text = text + self._attrs = attrs or {} + + def get(self, key): + return self._attrs.get(key) + + class FakeShape: + def __init__(self, bin_id=None, img_id=None): + self.bin_id = bin_id + self.img_id = img_id + + def find(self, pattern): + if "binData" in pattern and self.bin_id: + return FakeNode( + text="shape", + attrs={"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id": self.bin_id}, + ) + if "imagedata" in pattern and self.img_id: + return FakeNode(attrs={"id": self.img_id}) + return None + + class FakeChild: + def __init__( + self, + tag, + text="", + fld_chars=None, + instr_texts=None, + drawings=None, + shapes=None, + attrs=None, + hyperlink_runs=None, + ): + self.tag = tag + self.text = text + self._fld_chars = fld_chars or [] + self._instr_texts = instr_texts or [] + self._drawings = drawings or [] + self._shapes = shapes or [] + self._attrs = attrs or {} + self._hyperlink_runs = hyperlink_runs or [] + + def findall(self, pattern): + if pattern == qn("w:fldChar"): + return self._fld_chars + if pattern == qn("w:instrText"): + return self._instr_texts + if pattern == qn("w:r"): + return self._hyperlink_runs + if pattern.endswith("}drawing"): + return self._drawings + if pattern.endswith("}pict"): + return self._shapes + return [] + + def get(self, key): + return self._attrs.get(key) + + class FakeRun: + def __init__(self, element, paragraph): + self.element = element + self.text = getattr(element, "text", "") + + paragraph_main = SimpleNamespace( + _element=[ + FakeChild( + qn("w:r"), + text="run-text", + drawings=[FakeDrawing([ext_image_id, int_embed_id])], + shapes=[FakeShape(bin_id=shape_ext_id, img_id=shape_int_id)], + ), + FakeChild( + qn("w:r"), + text="", + drawings=[], + shapes=[FakeShape(bin_id=shape_ext_id)], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-ok"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="LinkText")], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-bad"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="BrokenLink")], + ), + ] + ) + paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")]) + + fake_doc = SimpleNamespace( + part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}), + paragraphs=[paragraph_main, paragraph_empty], + tables=[SimpleNamespace(rows=[])], + element=SimpleNamespace( + body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")] + ), + ) + + monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc) + monkeypatch.setattr(we, "Run", FakeRun) + monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map) + monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN") + logger_exception = MagicMock() + monkeypatch.setattr(we.logger, "exception", logger_exception) + + content = extractor.parse_docx("dummy.docx") + + assert "[EXT]" in content + assert "[INT]" in content + assert "[SHAPE_EXT]" in content + 
assert "[LinkText](https://example.com)" in content + assert "BrokenLink" in content + assert "TABLE-MARKDOWN" in content + logger_exception.assert_called_once() + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_http(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + # Build modern hyperlink inside table cell + r_id = "rIdHttp1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + # Relationship for external http link + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[Dify](https://dify.ai)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + r_id = "rIdMail1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "john@test.com" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "mailto:john@test.com", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[john@test.com](mailto:john@test.com)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py new file mode 100644 index 0000000000..26ce333e11 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py @@ -0,0 +1,300 @@ +"""Unit tests for unstructured extractors and their local/API partitioning paths.""" + +import base64 +import sys +import types +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.unstructured.unstructured_epub_extractor as epub_module +from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor +from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor +from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor +from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor +from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor +from 
core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor +from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor + + +def _register_module(monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object) -> types.ModuleType: + module = types.ModuleType(name) + for k, v in attrs.items(): + setattr(module, k, v) + monkeypatch.setitem(sys.modules, name, module) + return module + + +def _register_unstructured_packages(monkeypatch: pytest.MonkeyPatch) -> None: + _register_module(monkeypatch, "unstructured", __path__=[]) + _register_module(monkeypatch, "unstructured.partition", __path__=[]) + _register_module(monkeypatch, "unstructured.chunking", __path__=[]) + _register_module(monkeypatch, "unstructured.file_utils", __path__=[]) + + +def _install_chunk_by_title(monkeypatch: pytest.MonkeyPatch, chunks: list[SimpleNamespace]) -> None: + _register_unstructured_packages(monkeypatch) + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + return chunks + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + +class TestUnstructuredMarkdownMsgXml: + def test_markdown_extractor_without_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" chunk-1 "), SimpleNamespace(text=" chunk-2 ")]) + _register_module( + monkeypatch, "unstructured.partition.md", partition_md=lambda filename: [SimpleNamespace(text="x")] + ) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md").extract() + + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_markdown_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" via-api ")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="ignored")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "via-api" + assert calls == {"filename": "/tmp/file.md", "api_url": "https://u", "api_key": "k"} + + def test_msg_extractor_local(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + _register_module( + monkeypatch, "unstructured.partition.msg", partition_msg=lambda filename: [SimpleNamespace(text="x")] + ) + + assert UnstructuredMsgExtractor("/tmp/file.msg").extract()[0].page_content == "msg-doc" + + def test_msg_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredMsgExtractor("/tmp/file.msg", api_url="https://u", api_key="k").extract()[0].page_content + == "msg-doc" + ) + assert calls["filename"] == "/tmp/file.msg" + + def test_xml_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="xml-doc")]) + + xml_calls = {} + + def partition_xml(filename, xml_keep_tags): + xml_calls.update({"filename": filename, 
"xml_keep_tags": xml_keep_tags}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.xml", partition_xml=partition_xml) + + assert UnstructuredXmlExtractor("/tmp/file.xml").extract()[0].page_content == "xml-doc" + assert xml_calls == {"filename": "/tmp/file.xml", "xml_keep_tags": True} + + api_calls = {} + + def partition_via_api(filename, api_url, api_key): + api_calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredXmlExtractor("/tmp/file.xml", api_url="https://u", api_key="k").extract()[0].page_content + == "xml-doc" + ) + assert api_calls["filename"] == "/tmp/file.xml" + + +class TestUnstructuredEmailAndEpub: + def test_email_extractor_local_decodes_html_and_suppresses_decode_errors(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + captured = {} + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + captured["elements"] = list(elements) + return [SimpleNamespace(text=" chunked-email ")] + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + html = "Hello Email
" + encoded_html = base64.b64encode(html.encode("utf-8")).decode("utf-8") + bad_base64 = "not-base64" + + elements = [SimpleNamespace(text=encoded_html), SimpleNamespace(text=bad_base64)] + _register_module(monkeypatch, "unstructured.partition.email", partition_email=lambda filename: elements) + + docs = UnstructuredEmailExtractor("/tmp/file.eml").extract() + + assert docs[0].page_content == "chunked-email" + chunk_elements = captured["elements"] + assert "Hello Email" in chunk_elements[0].text + assert chunk_elements[1].text == bad_base64 + + def test_email_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="api-email")]) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="abc")], + ) + + docs = UnstructuredEmailExtractor("/tmp/file.eml", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "api-email" + + def test_epub_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="epub-doc")]) + + calls = {"download": 0, "partition": 0} + + def fake_download_pandoc(): + calls["download"] += 1 + + def partition_epub(filename, xml_keep_tags): + calls["partition"] += 1 + assert xml_keep_tags is True + return [SimpleNamespace(text="x")] + + monkeypatch.setattr(epub_module.pypandoc, "download_pandoc", fake_download_pandoc) + _register_module(monkeypatch, "unstructured.partition.epub", partition_epub=partition_epub) + + docs = UnstructuredEpubExtractor("/tmp/file.epub").extract() + + assert docs[0].page_content == "epub-doc" + assert calls == {"download": 1, "partition": 1} + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="x")], + ) + + docs = UnstructuredEpubExtractor("/tmp/file.epub", api_url="https://u", api_key="k").extract() + assert docs[0].page_content == "epub-doc" + + +class TestUnstructuredPPTAndPPTX: + def test_ppt_extractor_requires_api_url(self): + with pytest.raises(NotImplementedError, match="Unstructured API Url is not configured"): + UnstructuredPPTExtractor("/tmp/file.ppt").extract() + + def test_ppt_extractor_groups_text_by_page(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="A", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="B", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="skip", metadata=SimpleNamespace(page_number=None)), + SimpleNamespace(text="C", metadata=SimpleNamespace(page_number=2)), + ], + ) + + docs = UnstructuredPPTExtractor("/tmp/file.ppt", api_url="https://u", api_key="k").extract() + + assert [doc.page_content for doc in docs] == ["A\nB", "C"] + + def test_pptx_extractor_local_and_api(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.pptx", + partition_pptx=lambda filename: [ + SimpleNamespace(text="P1", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="P2", metadata=SimpleNamespace(page_number=2)), + SimpleNamespace(text="Skip", metadata=SimpleNamespace(page_number=None)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx").extract() + assert [doc.page_content for doc in docs] == ["P1", "P2"] + + _register_module( + monkeypatch, + 
"unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="X", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="Y", metadata=SimpleNamespace(page_number=1)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx", api_url="https://u", api_key="k").extract() + assert [doc.page_content for doc in docs] == ["X\nY"] + + +class TestUnstructuredWord: + def _install_doc_modules(self, monkeypatch, version: str, filetype_value): + _register_unstructured_packages(monkeypatch) + + class FileType: + DOC = "doc" + + _register_module(monkeypatch, "unstructured.__version__", __version__=version) + _register_module( + monkeypatch, + "unstructured.file_utils.filetype", + FileType=FileType, + detect_filetype=lambda filename: filetype_value, + ) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="api-doc")], + ) + _register_module( + monkeypatch, + "unstructured.partition.docx", + partition_docx=lambda filename: [SimpleNamespace(text="docx-doc")], + ) + _register_module( + monkeypatch, + "unstructured.chunking.title", + chunk_by_title=lambda elements, max_characters, combine_text_under_n_chars: [ + SimpleNamespace(text="chunk-1"), + SimpleNamespace(text="chunk-2"), + ], + ) + + def test_word_extractor_rejects_doc_on_old_unstructured_version(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="doc") + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + + def test_word_extractor_doc_and_docx_paths(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.11", filetype_value="doc") + + docs = UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + self._install_doc_modules(monkeypatch, version="0.5.0", filetype_value="not-doc") + docs = UnstructuredWordExtractor("/tmp/file.docx", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_word_extractor_magic_import_error_fallback_to_extension(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="not-used") + monkeypatch.setitem(sys.modules, "magic", None) + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() diff --git a/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py new file mode 100644 index 0000000000..d758be218a --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py @@ -0,0 +1,434 @@ +"""Unit tests for WaterCrawl client, provider, and extractor behavior.""" + +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import core.rag.extractor.watercrawl.client as client_module +from core.rag.extractor.watercrawl.client import BaseAPIClient, WaterCrawlAPIClient +from core.rag.extractor.watercrawl.exceptions import ( + WaterCrawlAuthenticationError, + WaterCrawlBadRequestError, + WaterCrawlPermissionError, +) +from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor +from core.rag.extractor.watercrawl.provider import WaterCrawlProvider + + +def _response( + 
status_code: int, + json_data: dict[str, Any] | None = None, + content_type: str = "application/json", + content: bytes = b"", + text: str = "", +) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.headers = {"Content-Type": content_type} + response.content = content + response.text = text + response.json.return_value = json_data if json_data is not None else {} + response.raise_for_status.return_value = None + response.close.return_value = None + return response + + +class TestWaterCrawlExceptions: + def test_bad_request_error_properties_and_string(self): + response = _response(400, {"message": "bad request", "errors": {"url": ["invalid"]}}) + + err = WaterCrawlBadRequestError(response) + parsed_errors = json.loads(err.flat_errors) + + assert err.status_code == 400 + assert err.message == "bad request" + assert "url" in parsed_errors + assert any("invalid" in str(item) for item in parsed_errors["url"]) + assert "WaterCrawlBadRequestError" in str(err) + + def test_permission_and_authentication_error_strings(self): + response = _response(403, {"message": "quota exceeded", "errors": {}}) + + permission = WaterCrawlPermissionError(response) + authentication = WaterCrawlAuthenticationError(response) + + assert "exceeding your WaterCrawl API limits" in str(permission) + assert "API key is invalid or expired" in str(authentication) + + +class TestBaseAPIClient: + def test_init_session_builds_expected_headers(self, monkeypatch): + captured = {} + + def fake_client(**kwargs): + captured.update(kwargs) + return "session" + + monkeypatch.setattr(client_module.httpx, "Client", fake_client) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client.session == "session" + assert captured["headers"]["X-API-Key"] == "k" + assert captured["headers"]["User-Agent"] == "WaterCrawl-Plugin" + + def test_request_stream_and_non_stream_paths(self, monkeypatch): + class FakeSession: + def __init__(self): + self.request_calls = [] + self.build_calls = [] + self.send_calls = [] + + def request(self, method, url, params=None, json=None, **kwargs): + self.request_calls.append((method, url, params, json, kwargs)) + return "non-stream-response" + + def build_request(self, method, url, params=None, json=None): + req = (method, url, params, json) + self.build_calls.append(req) + return req + + def send(self, request, stream=False, **kwargs): + self.send_calls.append((request, stream, kwargs)) + return "stream-response" + + fake_session = FakeSession() + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: fake_session) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client._request("GET", "/v1/items", query_params={"a": 1}) == "non-stream-response" + assert fake_session.request_calls[0][1] == "https://watercrawl.dev/v1/items" + + assert client._request("GET", "/v1/items", stream=True) == "stream-response" + assert fake_session.build_calls + assert fake_session.send_calls[0][1] is True + + def test_http_method_helpers_delegate_to_request(self, monkeypatch): + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: MagicMock()) + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + calls = [] + + def fake_request(method, endpoint, query_params=None, data=None, **kwargs): + calls.append((method, endpoint, query_params, data)) + return "ok" + + monkeypatch.setattr(client, "_request", fake_request) + + assert client._get("/a") == "ok" + assert client._post("/b", data={"x": 1}) == 
"ok" + assert client._put("/c", data={"x": 2}) == "ok" + assert client._delete("/d") == "ok" + assert client._patch("/e", data={"x": 3}) == "ok" + assert [c[0] for c in calls] == ["GET", "POST", "PUT", "DELETE", "PATCH"] + + +class TestWaterCrawlAPIClient: + def test_process_eventstream_and_download(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = MagicMock() + response.iter_lines.return_value = [ + b"event: keep-alive", + b'data: {"type":"result","data":{"result":"http://x"}}', + b'data: {"type":"log","data":{"msg":"ok"}}', + ] + + monkeypatch.setattr(client, "download_result", lambda data: {"result": {"markdown": "body"}, "url": "u"}) + + events = list(client.process_eventstream(response, download=True)) + + assert events[0]["data"]["result"]["markdown"] == "body" + assert events[1]["type"] == "log" + response.close.assert_called_once() + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (401, WaterCrawlAuthenticationError), + (403, WaterCrawlPermissionError), + (422, WaterCrawlBadRequestError), + ], + ) + def test_process_response_error_statuses(self, status: int, expected_exception: type[Exception]): + client = WaterCrawlAPIClient(api_key="k") + + with pytest.raises(expected_exception): + client.process_response(_response(status, {"message": "bad", "errors": {"url": ["x"]}})) + + def test_process_response_204_returns_none(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(204, None)) is None + + def test_process_response_json_payloads(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(200, {"ok": True})) == {"ok": True} + assert client.process_response(_response(200, None)) == {} + + def test_process_response_octet_stream_returns_bytes(self): + client = WaterCrawlAPIClient(api_key="k") + assert ( + client.process_response(_response(200, content_type="application/octet-stream", content=b"bin")) == b"bin" + ) + + def test_process_response_event_stream_returns_generator(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + generator = (item for item in [{"type": "result", "data": {}}]) + monkeypatch.setattr(client, "process_eventstream", lambda response, download=False: generator) + assert client.process_response(_response(200, content_type="text/event-stream")) is generator + + def test_process_response_unknown_content_type_raises(self): + client = WaterCrawlAPIClient(api_key="k") + with pytest.raises(Exception, match="Unknown response type"): + client.process_response(_response(200, content_type="text/plain", text="x")) + + def test_process_response_uses_raise_for_status(self): + client = WaterCrawlAPIClient(api_key="k") + response = _response(500, {"message": "server"}) + response.raise_for_status.side_effect = RuntimeError("http error") + + with pytest.raises(RuntimeError, match="http error"): + client.process_response(response) + + def test_endpoint_wrappers(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda resp: "processed") + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "get-resp") + monkeypatch.setattr(client, "_post", lambda *args, **kwargs: "post-resp") + monkeypatch.setattr(client, "_delete", lambda *args, **kwargs: "delete-resp") + + assert client.get_crawl_requests_list() == "processed" + assert client.get_crawl_request("id") == "processed" + assert client.create_crawl_request(url="https://x") == "processed" + assert 
client.stop_crawl_request("id") == "processed" + assert client.download_crawl_request("id") == "processed" + assert client.get_crawl_request_results("id") == "processed" + + def test_monitor_crawl_request_generator_and_validation(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda _: (x for x in [{"type": "result", "data": 1}])) + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "stream-resp") + + events = list(client.monitor_crawl_request("job-1", prefetched=True)) + assert events == [{"type": "result", "data": 1}] + + monkeypatch.setattr(client, "process_response", lambda _: [{"type": "result"}]) + with pytest.raises(ValueError, match="Generator expected"): + list(client.monitor_crawl_request("job-1")) + + def test_scrape_url_sync_and_async(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + monkeypatch.setattr(client, "create_crawl_request", lambda **kwargs: {"uuid": "job-1"}) + + async_result = client.scrape_url("https://example.com", sync=False) + assert async_result == {"uuid": "job-1"} + + monkeypatch.setattr( + client, + "monitor_crawl_request", + lambda item_id, prefetched: iter( + [{"type": "log", "data": {}}, {"type": "result", "data": {"url": "https://example.com"}}] + ), + ) + sync_result = client.scrape_url("https://example.com", sync=True) + assert sync_result == {"url": "https://example.com"} + + def test_download_result_fetches_json_and_closes(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = _response(200, {"markdown": "body"}) + monkeypatch.setattr(client_module.httpx, "get", lambda *args, **kwargs: response) + + result = client.download_result({"result": "https://example.com/result.json"}) + + assert result["result"] == {"markdown": "body"} + response.close.assert_called_once() + + +class TestWaterCrawlProvider: + def test_crawl_url_builds_options_and_min_wait_time(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + captured_kwargs = {} + + def create_crawl_request_spy(**kwargs): + captured_kwargs.update(kwargs) + return {"uuid": "job-1"} + + monkeypatch.setattr(provider.client, "create_crawl_request", create_crawl_request_spy) + + result = provider.crawl_url( + "https://example.com", + { + "crawl_sub_pages": True, + "limit": 5, + "max_depth": 2, + "includes": "a,b", + "excludes": "x,y", + "exclude_tags": "nav,footer", + "include_tags": "main", + "wait_time": 100, + "only_main_content": False, + }, + ) + + assert result == {"status": "active", "job_id": "job-1"} + assert captured_kwargs["url"] == "https://example.com" + assert captured_kwargs["spider_options"] == { + "max_depth": 2, + "page_limit": 5, + "allowed_domains": [], + "exclude_paths": ["x", "y"], + "include_paths": ["a", "b"], + } + assert captured_kwargs["page_options"]["exclude_tags"] == ["nav", "footer"] + assert captured_kwargs["page_options"]["include_tags"] == ["main"] + assert captured_kwargs["page_options"]["only_main_content"] is False + assert captured_kwargs["page_options"]["wait_time"] == 1000 + + def test_get_crawl_status_active_and_completed(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "running", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 3}}, + "number_of_documents": 1, + "duration": "00:00:01.500000", + }, + ) + + active = provider.get_crawl_status("job-1") + assert active["status"] == "active" + assert active["data"] 
== [] + assert active["time_consuming"] == pytest.approx(1.5) + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "completed", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 2}}, + "number_of_documents": 2, + "duration": "00:00:02.000000", + }, + ) + monkeypatch.setattr(provider, "_get_results", lambda crawl_request_id, query_params=None: iter([{"url": "u"}])) + + completed = provider.get_crawl_status("job-2") + assert completed["status"] == "completed" + assert completed["data"] == [{"url": "u"}] + + def test_get_crawl_url_data_and_scrape(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr(provider, "scrape_url", lambda url: {"source_url": url}) + assert provider.get_crawl_url_data("", "https://example.com") == {"source_url": "https://example.com"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([{"source_url": "u1"}])) + assert provider.get_crawl_url_data("job", "u1") == {"source_url": "u1"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([])) + assert provider.get_crawl_url_data("job", "u1") is None + + def test_structure_data_validation_and_get_results_pagination(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + with pytest.raises(ValueError, match="Invalid result object"): + provider._structure_data({"result": "not-a-dict"}) + + structured = provider._structure_data( + { + "url": "https://example.com", + "result": { + "metadata": {"title": "Title", "description": "Desc"}, + "markdown": "Body", + }, + } + ) + assert structured["title"] == "Title" + assert structured["markdown"] == "Body" + + responses = [ + { + "results": [ + { + "url": "https://a", + "result": {"metadata": {"title": "A", "description": "DA"}, "markdown": "MA"}, + } + ], + "next": "next-page", + }, + {"results": [], "next": None}, + ] + + monkeypatch.setattr( + provider.client, + "get_crawl_request_results", + lambda crawl_request_id, page, page_size, query_params: responses.pop(0), + ) + + results = list(provider._get_results("job-1")) + assert len(results) == 1 + assert results[0]["source_url"] == "https://a" + + def test_scrape_url_uses_client_and_structure(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + monkeypatch.setattr( + provider.client, "scrape_url", lambda **kwargs: {"result": {"metadata": {}, "markdown": "m"}, "url": "u"} + ) + + result = provider.scrape_url("u") + + assert result["source_url"] == "u" + + +class TestWaterCrawlWebExtractor: + def test_extract_crawl_and_scrape_modes(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: { + "markdown": "crawl", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_scrape_url_data", + lambda provider, url, tenant_id, only_main_content: { + "markdown": "scrape", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + + crawl_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + scrape_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert crawl_extractor.extract()[0].page_content == "crawl" + assert scrape_extractor.extract()[0].page_content == "scrape" + + def test_extract_crawl_returns_empty_when_service_returns_none(self, 
monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: None, + ) + + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_unknown_mode_returns_empty(self): + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="other") + + assert extractor.extract() == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/conftest.py b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py new file mode 100644 index 0000000000..2a3860e107 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py @@ -0,0 +1,33 @@ +from contextlib import AbstractContextManager, nullcontext +from typing import Any + +import pytest + + +class _FakeFlaskApp: + def app_context(self) -> AbstractContextManager[None]: + return nullcontext() + + +class _FakeExecutor: + def __init__(self, future: Any) -> None: + self._future = future + + def __enter__(self) -> "_FakeExecutor": + return self + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool: + return False + + def submit(self, func: object, preview: object) -> Any: + return self._future + + +@pytest.fixture +def fake_flask_app() -> _FakeFlaskApp: + return _FakeFlaskApp() + + +@pytest.fixture +def fake_executor_cls() -> type[_FakeExecutor]: + return _FakeExecutor diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py new file mode 100644 index 0000000000..2451db70b6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -0,0 +1,629 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from dify_graph.model_runtime.entities.model_entities import ModelFeature + + +class TestParagraphIndexProcessor: + @pytest.fixture + def processor(self) -> ParagraphIndexProcessor: + return ParagraphIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def _llm_result(self, content: str = "summary") -> LLMResult: + return LLMResult( + model="llm-model", + message=AssistantPromptMessage(content=content), + usage=LLMUsage.empty_usage(), + ) + + def test_extract_forwards_automatic_flag(self, 
processor: ParagraphIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected_docs + docs = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParagraphIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + rules_without_segmentation = SimpleNamespace(segmentation=None) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=rules_without_segmentation, + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".first", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + return_value=".first", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + ): + documents = processor.transform([source_document], process_rule=process_rule) + + assert len(documents) == 1 + assert documents[0].page_content == "first" + assert documents[0].attachments is not None + assert documents[0].metadata["doc_hash"] == "hash" + + def test_transform_automatic_mode_uses_default_rules(self, processor: ParagraphIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content="text", metadata={})] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ) as mock_validate, + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + 
return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_get_content_files", return_value=[]), + ): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "automatic"}) + + assert mock_validate.call_count == 1 + + def test_load_creates_vector_and_multimodal_when_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + docs = [Document(page_content="chunk", metadata={})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + vector = mock_vector_cls.return_value + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + mock_keyword_cls.assert_not_called() + + def test_load_uses_keyword_add_texts_with_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs, keywords_list=["k1", "k2"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs, keywords_list=["k1", "k2"]) + + def test_load_uses_keyword_add_texts_without_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) + + def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_economy_deletes_summaries_and_keywords( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + mock_keyword_cls.return_value.delete.assert_called_once() + 
+ def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.clean(dataset, ["node-2"], with_keywords=True) + + mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"]) + + def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9) + rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [accepted, rejected] + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + def test_index_list_chunks_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, ["chunk-1", "chunk-2"]) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_list_chunks_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.index(dataset, dataset_document, ["chunk-3"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once() + + def test_index_multimodal_structure_handles_files_and_account_lookup( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + chunk_with_files = SimpleNamespace( + content="content-1", + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + chunk_without_files = SimpleNamespace(content="content-2", files=None) + structure = SimpleNamespace(general_chunks=[chunk_with_files, chunk_without_files]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", 
return_value=[AttachmentDocument(page_content="img", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector"), + ): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + assert mock_files.call_count == 1 + + def test_index_multimodal_structure_requires_valid_account( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + structure = SimpleNamespace(general_chunks=[SimpleNamespace(content="content", files=None)]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + def test_format_preview_validates_chunk_shape(self, processor: ParagraphIndexProcessor) -> None: + preview = processor.format_preview(["chunk-1", "chunk-2"]) + assert preview["chunk_structure"] == "text_model" + assert preview["total_segments"] == 2 + + with pytest.raises(ValueError, match="Chunks is not a list"): + processor.format_preview({"not": "a-list"}) + + def test_generate_summary_preview_success_and_failure(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())): + result = processor.generate_summary_preview( + "tenant-1", preview_items, {"enable": True}, doc_language="English" + ) + assert all(item.summary == "summary" for item in result) + + with patch.object(processor, "generate_summary", side_effect=RuntimeError("summary failed")): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", [PreviewDetail(content="chunk-1")], {"enable": True}) + + def test_generate_summary_preview_fallback_without_flask_context(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())), + ): + result = processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_timeout( + self, processor: ParagraphIndexProcessor, fake_executor_cls: type + ) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + future.cancel.assert_called_once() + + def test_generate_summary_validates_input(self) -> None: + with pytest.raises(ValueError, match="must 
be enabled"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": False}) + + with pytest.raises(ValueError, match="model_name and model_provider_name"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True}) + + def test_generate_summary_text_only_flow(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + model_instance.invoke_llm.return_value = self._llm_result("text summary") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota", + side_effect=RuntimeError("quota"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, usage = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + document_language="English", + ) + + assert summary == "text summary" + assert isinstance(usage, LLMUsage) + mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota") + + def test_generate_summary_handles_vision_and_image_conversion(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = self._llm_result("vision summary") + image_file = SimpleNamespace() + image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch.object( + ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[image_file] + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[]) as mock_extract_text, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + return_value=image_content, + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, _ = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + segment_id="seg-1", + ) + + assert summary == "vision summary" + mock_extract_text.assert_not_called() + + def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = object() + image_file = SimpleNamespace() + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + 
"core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DEFAULT_GENERATOR_SUMMARY_PROMPT", + "Prompt {missing}", + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[]), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[image_file]), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + side_effect=RuntimeError("bad image"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + with pytest.raises(ValueError, match="Expected LLMResult"): + ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + ) + + mock_logger.warning.assert_called_with( + "Failed to convert image file to prompt message content: %s", "bad image" + ) + + def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None: + text = ( + " " + " " + "" + ) + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + non_image_upload = SimpleNamespace( + id="22222222-2222-2222-2222-222222222222", + tenant_id="tenant-1", + name="file.txt", + mime_type="text/plain", + extension="txt", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload, non_image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + return_value=SimpleNamespace(id="file-1"), + ) as mock_builder, + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert len(files) == 1 + assert mock_builder.call_count == 1 + mock_logger.warning.assert_not_called() + + def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None: + assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == [] + + def test_extract_images_from_text_logs_when_build_fails(self) -> None: + text = "" + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + side_effect=RuntimeError("build failed"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert files == [] + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments(self) -> None: + 
image_upload = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="", + size=1, + key="k1", + ) + bad_upload = SimpleNamespace( + id="file-2", + name="broken", + extension=None, + mime_type="image/png", + source_url="", + size=1, + key="k2", + ) + non_image_upload = SimpleNamespace( + id="file-3", + name="text", + extension="txt", + mime_type="text/plain", + source_url="", + size=1, + key="k3", + ) + execute_result = Mock() + execute_result.all.return_value = [(None, image_upload), (None, bad_upload), (None, non_image_upload)] + session = Mock() + session.execute.return_value = execute_result + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert len(files) == 1 + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments_empty(self) -> None: + execute_result = Mock() + execute_result.all.return_value = [] + session = Mock() + session.execute.return_value = execute_result + + with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session): + empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert empty_files == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py new file mode 100644 index 0000000000..abe40f05d1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -0,0 +1,523 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.models.document import AttachmentDocument, ChildDocument, Document +from services.entities.knowledge_entities.knowledge_entities import ParentMode + + +class TestParentChildIndexProcessor: + @pytest.fixture + def processor(self) -> ParentChildIndexProcessor: + return ParentChildIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + document.dataset_process_rule_id = None + return document + + def _segmentation(self) -> SimpleNamespace: + return SimpleNamespace(max_tokens=200, chunk_overlap=10, separator="\n") + + def _paragraph_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + segmentation=self._segmentation(), + subchunk_segmentation=self._segmentation(), + ) + + def _full_doc_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.FULL_DOC, segmentation=None, subchunk_segmentation=self._segmentation() + ) + + def test_extract_forwards_automatic_flag(self, processor: ParentChildIndexProcessor) -> None: + extract_setting = Mock() + expected = [Document(page_content="chunk", metadata={})] + + with patch( + 
"core.rag.index_processor.processor.parent_child_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected + documents = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert documents == expected + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParentChildIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_paragraph_requires_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(parent_mode=ParentMode.PARAGRAPH, segmentation=None) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", return_value=rules + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_paragraph_builds_parent_and_child_docs(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".parent", metadata={}), + Document(page_content=" ", metadata={}), + ] + parent_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + child_docs = [ChildDocument(page_content="child-1", metadata={"dataset_id": "dataset-1"})] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + return_value=".parent", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + ): + result = processor.transform( + [parent_document], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=False, + ) + + assert len(result) == 1 + assert result[0].page_content == "parent" + assert result[0].children == child_docs + assert result[0].attachments is not None + + def test_transform_preview_returns_after_ten_parent_chunks(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content=f"chunk-{i}", metadata={}) for i in range(10)] + documents = [ + Document(page_content="doc-1", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"dataset_id": "dataset-1", "document_id": "doc-2"}), + ] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + 
"core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch.object(processor, "_split_child_nodes", return_value=[]), + ): + result = processor.transform( + documents, + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 10 + + def test_transform_full_doc_mode_trims_children_for_preview(self, processor: ParentChildIndexProcessor) -> None: + docs = [ + Document(page_content="first", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="second", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + ] + child_docs = [ChildDocument(page_content=f"child-{i}", metadata={}) for i in range(5)] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._full_doc_rules(), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.dify_config.CHILD_CHUNKS_PREVIEW_NUMBER", + 2, + ), + ): + result = processor.transform( + docs, + process_rule={"mode": "hierarchical", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 1 + assert len(result[0].children or []) == 2 + assert result[0].attachments is not None + + def test_load_creates_vectors_for_child_docs(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + parent_doc = Document( + page_content="parent", + metadata={}, + children=[ + ChildDocument(page_content="child-1", metadata={}), + ChildDocument(page_content="child-2", metadata={}), + ], + ) + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, [parent_doc], multimodal_documents=multimodal_docs) + + assert vector.create.call_count == 1 + formatted_docs = vector.create.call_args[0][0] + assert len(formatted_docs) == 2 + assert all(isinstance(doc, Document) for doc in formatted_docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + delete_query = Mock() + where_query = Mock() + where_query.delete.return_value = 2 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean( + dataset, + ["node-1"], + delete_child_chunks=True, + precomputed_child_node_ids=["child-1", "child-2"], + ) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + 
def test_clean_queries_child_ids_when_not_precomputed( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + child_query = Mock() + child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] + session = Mock() + session.query.return_value = child_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_child_chunks=False) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + + def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + where_query = Mock() + where_query.delete.return_value = 3 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_child_chunks=True) + + vector.delete.assert_called_once() + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.session_factory.create_session", + return_value=session_ctx, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, ["node-1"], delete_summaries=True, precomputed_child_node_ids=[]) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + + def test_clean_deletes_all_summaries_when_node_ids_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + + def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8) + low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [ok_result, low_result] + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].page_content == "keep" + assert docs[0].metadata["score"] == 0.8 + + def 
test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=None) + + with pytest.raises(ValueError, match="No subchunk segmentation found"): + processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None) + + def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=self._segmentation()) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".child-1", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + ): + child_docs = processor._split_child_nodes( + Document(page_content="parent", metadata={}), rules, "custom", None + ) + + assert len(child_docs) == 1 + assert child_docs[0].page_content == "child-1" + assert child_docs[0].metadata["doc_hash"] == "hash" + + def test_index_creates_process_rule_segments_and_vectors( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[ + SimpleNamespace( + parent_content="parent text", + child_contents=["child-1", "child-2"], + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + ], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + side_effect=lambda text: f"hash-{text}", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + assert dataset_document.dataset_process_rule_id == "rule-1" + session.add.assert_called_once_with(dataset_rule) + session.flush.assert_called_once() + session.commit.assert_called_once() + mock_store_cls.return_value.add_documents.assert_called_once() + assert mock_vector_cls.return_value.create.call_count == 1 + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_uses_content_files_when_files_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + 
return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + mock_files.assert_called_once() + + def test_index_raises_when_account_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + def test_format_preview_returns_parent_child_structure(self, processor: ParentChildIndexProcessor) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child-1", "child-2"])], + ) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ): + preview = processor.format_preview({"parent_child_chunks": []}) + + assert preview["chunk_structure"] == "hierarchical_model" + assert preview["parent_mode"] == ParentMode.PARAGRAPH + assert preview["total_segments"] == 1 + + def test_generate_summary_preview_sets_summaries(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ): + result = processor.generate_summary_preview( + "tenant-1", preview_texts, {"enable": True}, doc_language="English" + ) + + assert all(item.summary == "summary" for item in result) + + def test_generate_summary_preview_raises_when_worker_fails(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + side_effect=RuntimeError("summary failed"), + ): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + def test_generate_summary_preview_falls_back_without_flask_context( + self, processor: 
ParentChildIndexProcessor + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ), + ): + result = processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_handles_timeout( + self, processor: ParentChildIndexProcessor, fake_executor_cls: type + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + future.cancel.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py new file mode 100644 index 0000000000..8596647ef3 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -0,0 +1,382 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ImmediateThread: + def __init__(self, target, args=(), kwargs=None): + self._target = target + self._args = args + self._kwargs = kwargs or {} + + def start(self) -> None: + self._target(*self._args, **self._kwargs) + + def join(self) -> None: + return None + + +class TestQAIndexProcessor: + @pytest.fixture + def processor(self) -> QAIndexProcessor: + return QAIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def test_extract_forwards_automatic_flag(self, processor: QAIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.ExtractProcessor.extract") as mock_extract: + mock_extract.return_value = expected_docs + + docs = processor.extract(extract_setting, process_rule_mode="automatic") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def 
test_transform_rejects_none_process_rule(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + def test_transform_rejects_missing_rules_key(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_preview_calls_formatter_once( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + split_node = Document(page_content=".question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content="Q1", metadata={"answer": "A1"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", return_value="clean text" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform( + [document], + process_rule=process_rule, + preview=True, + tenant_id="tenant-1", + doc_language="English", + ) + + assert len(result) == 1 + assert result[0].metadata["answer"] == "A1" + mock_format.assert_called_once() + + def test_transform_non_preview_uses_thread_batches( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + documents = [ + Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), + Document(page_content="doc-2", metadata={"document_id": "doc-2", "dataset_id": "dataset-1"}), + ] + split_node = Document(page_content="question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content=f"Q-{document_node.page_content}", metadata={"answer": "A"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + 
patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + patch( + "core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread + ), + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1") + + assert len(result) == 2 + assert mock_format.call_count == 2 + + def test_format_by_template_validates_file_type(self, processor: QAIndexProcessor) -> None: + not_csv_file = Mock(spec=FileStorage) + not_csv_file.filename = "qa.txt" + + with pytest.raises(ValueError, match="Only CSV files"): + processor.format_by_template(not_csv_file) + + def test_format_by_template_parses_csv_rows(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + dataframe = pd.DataFrame([["Q1", "A1"], ["Q2", "A2"]]) + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=dataframe): + docs = processor.format_by_template(csv_file) + + assert [doc.page_content for doc in docs] == ["Q1", "Q2"] + assert [doc.metadata["answer"] for doc in docs] == ["A1", "A2"] + + def test_format_by_template_raises_on_empty_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=pd.DataFrame()): + with pytest.raises(ValueError, match="empty"): + processor.format_by_template(csv_file) + + def test_format_by_template_raises_on_invalid_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch( + "core.rag.index_processor.processor.qa_index_processor.pd.read_csv", side_effect=Exception("bad csv") + ): + with pytest.raises(ValueError, match="bad csv"): + processor.format_by_template(csv_file) + + def test_load_creates_vectors_for_high_quality_dataset(self, processor: QAIndexProcessor, dataset: Mock) -> None: + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + processor.load(dataset, docs) + + mock_vector_cls.assert_not_called() + + def test_clean_handles_summary_deletion_and_vector_cleanup( + self, processor: QAIndexProcessor, dataset: Mock + ) -> None: + mock_segment = SimpleNamespace(id="seg-1") + mock_query = Mock() + mock_query.filter.return_value.all.return_value = [mock_segment] + mock_session = Mock() + mock_session.query.return_value = mock_query + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.session_factory.create_session", + 
return_value=session_context, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_handles_dataset_wide_cleanup(self, processor: QAIndexProcessor, dataset: Mock) -> None: + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + vector.delete.assert_called_once() + + def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None: + result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9) + result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1) + + with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: + mock_retrieve.return_value = [result_ok, result_low] + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].page_content == "accepted" + assert docs[0].metadata["score"] == 0.9 + + def test_index_adds_documents_and_vectors_for_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + qa_chunks = SimpleNamespace( + qa_chunks=[ + SimpleNamespace(question="Q1", answer="A1"), + SimpleNamespace(question="Q2", answer="A2"), + ] + ) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore") as mock_store_cls, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + + def test_index_requires_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore"), + ): + with pytest.raises(ValueError, match="must be high quality"): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + def test_format_preview_returns_qa_preview(self, processor: QAIndexProcessor) -> None: + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) 
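+ # format_preview is expected to validate the raw chunk dict via QAStructureChunk + # and report the qa_model structure, segment count, and question/answer pairs.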
+ + with patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ): + preview = processor.format_preview({"qa_chunks": []}) + + assert preview["chunk_structure"] == "qa_model" + assert preview["total_segments"] == 1 + assert preview["qa_preview"] == [{"question": "Q1", "answer": "A1"}] + + def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None: + preview_items = [PreviewDetail(content="Q1")] + assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items + + def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + blank_document = Document(page_content=" ", metadata={}) + + processor._format_qa_document(fake_flask_app, "tenant-1", blank_document, all_qa_documents, "English") + + assert all_qa_documents == [] + + def test_format_qa_document_builds_question_answer_documents( + self, processor: QAIndexProcessor, fake_flask_app + ) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + return_value="Q1: What is this?\nA1: A test.\nQ2: Why?\nA2: Coverage.", + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert len(all_qa_documents) == 2 + assert all_qa_documents[0].page_content == "What is this?" + assert all_qa_documents[0].metadata["answer"] == "A test." + assert all_qa_documents[1].metadata["answer"] == "Coverage." 
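+ + # One plausible shape for the Qn:/An: pairing exercised above (an assumption, not + # necessarily the actual implementation): re.findall(r"Q\d+:\s*(.+?)\s*A\d+:\s*([^\n]+)", text, re.S).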
+ + def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + side_effect=RuntimeError("llm failure"), + ), + patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger, + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert all_qa_documents == [] + mock_logger.exception.assert_called_once_with("Failed to format qa document") + + def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None: + parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n") + + assert parsed == [{"question": "First?", "answer": "One."}, {"question": "Second?", "answer": "Two."}] diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py new file mode 100644 index 0000000000..b31bb6eea7 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -0,0 +1,291 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import httpx +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ForwardingBaseIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting, **kwargs): + return super().extract(extract_setting, **kwargs) + + def transform(self, documents, current_user=None, **kwargs): + return super().transform(documents, current_user=current_user, **kwargs) + + def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None): + return super().generate_summary_preview( + tenant_id=tenant_id, + preview_texts=preview_texts, + summary_index_setting=summary_index_setting, + doc_language=doc_language, + ) + + def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs): + return super().load( + dataset=dataset, + documents=documents, + multimodal_documents=multimodal_documents, + with_keywords=with_keywords, + **kwargs, + ) + + def clean(self, dataset, node_ids, with_keywords=True, **kwargs): + return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs) + + def index(self, dataset, document, chunks): + return super().index(dataset=dataset, document=document, chunks=chunks) + + def format_preview(self, chunks): + return super().format_preview(chunks) + + def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model): + return super().retrieve( + retrieval_method=retrieval_method, + query=query, + dataset=dataset, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + + +class TestBaseIndexProcessor: + @pytest.fixture + def processor(self) -> _ForwardingBaseIndexProcessor: + return _ForwardingBaseIndexProcessor() + + def test_abstract_methods_raise_not_implemented(self, processor: _ForwardingBaseIndexProcessor) -> None: + with pytest.raises(NotImplementedError): + processor.extract(Mock()) + with 
pytest.raises(NotImplementedError): + processor.transform([]) + with pytest.raises(NotImplementedError): + processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {}) + with pytest.raises(NotImplementedError): + processor.load(Mock(), []) + with pytest.raises(NotImplementedError): + processor.clean(Mock(), None) + with pytest.raises(NotImplementedError): + processor.index(Mock(), Mock(), {}) + with pytest.raises(NotImplementedError): + processor.format_preview([]) + with pytest.raises(NotImplementedError): + processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {}) + + def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None: + with patch( + "core.rag.index_processor.index_processor_base.dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", 1000 + ): + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 49, 0, "", None) + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 1001, 0, "", None) + + def test_get_splitter_custom_mode_uses_fixed_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + fixed_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.FixedRecursiveCharacterTextSplitter.from_encoder", + return_value=fixed_splitter, + ) as mock_fixed: + splitter = processor._get_splitter("hierarchical", 120, 10, "\\n\\n", None) + + assert splitter is fixed_splitter + assert mock_fixed.call_args.kwargs["fixed_separator"] == "\n\n" + assert mock_fixed.call_args.kwargs["chunk_size"] == 120 + + def test_get_splitter_automatic_mode_uses_enhance_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + auto_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.EnhanceRecursiveCharacterTextSplitter.from_encoder", + return_value=auto_splitter, + ) as mock_enhance: + splitter = processor._get_splitter("automatic", 0, 0, "", None) + + assert splitter is auto_splitter + assert "chunk_size" in mock_enhance.call_args.kwargs + + def test_extract_markdown_images(self, processor: _ForwardingBaseIndexProcessor) -> None: + markdown = "text ![img](https://a/img.png) and ![file](/files/123/file-preview)" + images = processor._extract_markdown_images(markdown) + assert images == ["https://a/img.png", "/files/123/file-preview"] + + def test_get_content_files_without_images_returns_empty(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="no image markdown", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + assert processor._get_content_files(document) == [] + + def test_get_content_files_handles_all_sources_and_duplicates( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = [ + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/file-preview", + "/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", + "https://example.com/remote.png?x=1", + ] + upload_a = SimpleNamespace(id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", name="a.png") + upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") + upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") + upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") + db_query = Mock() + db_query.where.return_value.all.return_value = [upload_a, upload_b,
upload_tool, upload_remote] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch.object(processor, "_download_tool_file", return_value="tool-upload-id") as mock_tool_download, + patch.object(processor, "_download_image", return_value="remote-upload-id") as mock_image_download, + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document, current_user=Mock()) + + assert len(files) == 5 + assert all(isinstance(file, AttachmentDocument) for file in files) + assert files[0].metadata["doc_type"] == DocType.IMAGE + assert files[0].metadata["document_id"] == "doc-1" + assert files[0].metadata["dataset_id"] == "ds-1" + assert files[0].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + assert files[1].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_tool_download.assert_called_once() + mock_image_download.assert_called_once() + + def test_get_content_files_skips_tool_and_remote_download_without_user( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", "https://example.com/remote.png"] + + with patch.object(processor, "_extract_markdown_images", return_value=images): + files = processor._get_content_files(document, current_user=None) + + assert files == [] + + def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] + db_query = Mock() + db_query.where.return_value.all.return_value = [] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document) + + assert files == [] + + def test_download_image_success_with_filename_from_content_disposition( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + response = Mock() + response.headers = { + "Content-Length": "4", + "content-disposition": "attachment; filename=test-image.png", + "content-type": "image/png", + } + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"data"] + upload_result = SimpleNamespace(id="upload-id") + + mock_db = Mock() + mock_db.engine = Mock() + + with ( + patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response), + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + upload_id = processor._download_image("https://example.com/test.png", current_user=Mock()) + + assert upload_id == "upload-id" + mock_file_service.return_value.upload_file.assert_called_once() + + def test_download_image_validates_size_and_empty_content(self, processor: _ForwardingBaseIndexProcessor) -> None: + too_large = Mock() + too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"} + too_large.raise_for_status.return_value = None + + with 
patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large): + assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None + + empty = Mock() + empty.headers = {"Content-Length": "0", "content-type": "image/png"} + empty.raise_for_status.return_value = None + empty.iter_bytes.return_value = [] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty): + assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None + + def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None: + response = Mock() + response.headers = {"content-type": "image/png"} + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response): + assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None + + def test_download_image_handles_timeout_request_and_unexpected_errors( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + request = httpx.Request("GET", "https://example.com/image.png") + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.TimeoutException("timeout"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.RequestError("bad request", request=request), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=RuntimeError("unexpected"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + db_query = Mock() + db_query.where.return_value.first.return_value = None + db_session = Mock() + db_session.query.return_value = db_query + + with patch("core.rag.index_processor.index_processor_base.db.session", db_session): + assert processor._download_tool_file("tool-id", current_user=Mock()) is None + + def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") + db_query = Mock() + db_query.where.return_value.first.return_value = tool_file + db_session = Mock() + db_session.query.return_value = db_query + mock_db = Mock() + mock_db.session = db_session + mock_db.engine = Mock() + upload_result = SimpleNamespace(id="upload-id") + + with ( + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("core.rag.index_processor.index_processor_base.storage.load_once", return_value=b"blob") as mock_load, + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + result = processor._download_tool_file("tool-id", current_user=Mock()) + + assert result == "upload-id" + mock_load.assert_called_once_with("k1") + mock_file_service.return_value.upload_file.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py 
b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py new file mode 100644 index 0000000000..0fc666dbbf --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py @@ -0,0 +1,42 @@ +import pytest + +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor + + +class TestIndexProcessorFactory: + def test_requires_index_type(self) -> None: + factory = IndexProcessorFactory(index_type=None) + + with pytest.raises(ValueError, match="Index type must be specified"): + factory.init_index_processor() + + def test_builds_paragraph_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARAGRAPH_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParagraphIndexProcessor) + + def test_builds_qa_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.QA_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, QAIndexProcessor) + + def test_builds_parent_child_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARENT_CHILD_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParentChildIndexProcessor) + + def test_rejects_unsupported_index_type(self) -> None: + factory = IndexProcessorFactory(index_type="unsupported") + + with pytest.raises(ValueError, match="is not supported"): + factory.init_index_processor() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 0e53482c51..b150d677f1 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -12,13 +12,18 @@ All tests use mocking to avoid external dependencies and ensure fast, reliable e Tests follow the Arrange-Act-Assert pattern for clarity. 
""" +from operator import itemgetter +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest from core.model_manager import ModelInstance +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode @@ -26,7 +31,7 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -def create_mock_model_instance(): +def create_mock_model_instance() -> ModelInstance: """Create a properly configured mock ModelInstance for reranking tests.""" mock_instance = Mock(spec=ModelInstance) # Setup provider_model_bundle chain for check_model_support_vision @@ -59,14 +64,7 @@ class TestRerankModelRunner: @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" - mock_instance = Mock(spec=ModelInstance) - # Setup provider_model_bundle chain for check_model_support_vision - mock_instance.provider_model_bundle = Mock() - mock_instance.provider_model_bundle.configuration = Mock() - mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" - mock_instance.provider = "test-provider" - mock_instance.model_name = "test-model" - return mock_instance + return create_mock_model_instance() @pytest.fixture def rerank_runner(self, mock_model_instance): @@ -382,6 +380,206 @@ class TestRerankModelRunner: assert call_kwargs["user"] == "user123" +class _ForwardingBaseRerankRunner(BaseRerankRunner): + def run( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> list[Document]: + return super().run( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=top_n, + user=user, + query_type=query_type, + ) + + +class TestBaseRerankRunner: + def test_run_raises_not_implemented(self): + runner = _ForwardingBaseRerankRunner() + + with pytest.raises(NotImplementedError): + runner.run(query="python", documents=[]) + + +class TestRerankModelRunnerMultimodal: + @pytest.fixture + def mock_model_instance(self): + return create_mock_model_instance() + + @pytest.fixture + def rerank_runner(self, mock_model_instance): + return RerankModelRunner(rerank_model_instance=mock_model_instance) + + def test_run_returns_original_documents_for_non_text_query_without_vision_support( + self, rerank_runner, mock_model_instance + ): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), + ] + + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) + + assert result == documents + mock_model_instance.invoke_rerank.assert_not_called() + + def test_run_uses_multimodal_path_when_vision_support_is_enabled(self, rerank_runner): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1", "source": "wiki"}, provider="dify"), + ] + 
rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="doc", score=0.88)], + ) + + with ( + patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch.object( + rerank_runner, + "fetch_multimodal_rerank", + return_value=(rerank_result, documents), + ) as mock_multimodal, + ): + mock_mm.return_value.check_model_support_vision.return_value = True + result = rerank_runner.run(query="python", documents=documents, query_type=QueryType.TEXT_QUERY) + + mock_multimodal.assert_called_once() + assert len(result) == 1 + assert result[0].metadata["score"] == 0.88 + + def test_fetch_multimodal_rerank_builds_docs_and_calls_text_rerank(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-1", "doc_type": DocType.IMAGE}, + provider="dify", + ) + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + external_doc = Document( + page_content="external-content", + metadata={}, + provider="external", + ) + query = Mock() + query.where.return_value.first.return_value = SimpleNamespace(key="image-key") + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc, text_doc, external_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc, text_doc, external_doc, external_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert len(unique_documents) == 3 + mock_load_once.assert_called_once_with("image-key") + text_rerank_call_args = mock_text_rerank.call_args.args + assert len(text_rerank_call_args[1]) == 3 + + def test_fetch_multimodal_rerank_skips_missing_image_upload(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, + provider="dify", + ) + query = Mock() + query.where.return_value.first.return_value = None + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [image_doc] + docs_arg = mock_text_rerank.call_args.args[1] + assert len(docs_arg) == 1 + + def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance): + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + query_chain = Mock() + query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key") + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="text-content", score=0.77)], + ) + mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + 
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="query-upload-id", + documents=[text_doc], + score_threshold=0.2, + top_n=2, + user="user-1", + query_type=QueryType.IMAGE_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [text_doc] + invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs + assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE + assert invoke_kwargs["docs"][0]["content"] == "text-content" + assert invoke_kwargs["user"] == "user-1" + + def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): + query_chain = Mock() + query_chain.where.return_value.first.return_value = None + + with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain): + with pytest.raises(ValueError, match="Upload file not found for query"): + rerank_runner.fetch_multimodal_rerank( + query="missing-upload-id", + documents=[], + query_type=QueryType.IMAGE_QUERY, + ) + + def test_fetch_multimodal_rerank_rejects_unsupported_query_type(self, rerank_runner): + with pytest.raises(ValueError, match="is not supported"): + rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[], + query_type="unsupported_query_type", + ) + + class TestWeightRerankRunner: """Unit tests for WeightRerankRunner. @@ -512,34 +710,39 @@ class TestWeightRerankRunner: - TF-IDF scores are calculated correctly - Cosine similarity is computed for keyword vectors """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - - # Mock keyword extraction with specific keywords + keyword_map = { + "python programming": ["python", "programming"], + "Python is a programming language": ["python", "programming", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.side_effect = [ - ["python", "programming"], # query - ["python", "programming", "language"], # doc1 - ["javascript", "web"], # doc2 - ["java", "programming"], # doc3 - ] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("python programming", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "python programming", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="python programming", documents=sample_documents_with_vectors) - # Assert: Keywords are extracted and scores are calculated - assert len(result) == 3 - # Document 1 should have highest keyword score (matches both query terms) - # Document 3 should have medium score (matches one term) - # Document 2 should have lowest score (matches no terms) 
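+        # The expected values recomputed above mirror the runner's weighted
+        # combination: with this fixture's weights the score is assumed to be
+        #   0.6 * cosine(query_vec, doc_vec) + 0.4 * keyword_tfidf_score,
+        # so the assertions below check both ordering and magnitude against it.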
+ expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_vector_score_calculation( self, @@ -556,30 +759,42 @@ class TestWeightRerankRunner: - Cosine similarity is calculated with document vectors - Vector scores are properly normalized """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test query": ["test"], + "Python is a programming language": ["python"], + "JavaScript for web development": ["javascript"], + "Java object-oriented programming": ["java"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding model mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance - # Mock cache embedding with specific query vector mock_cache_instance = MagicMock() query_vector = [0.2, 0.3, 0.4, 0.5] mock_cache_instance.embed_query.return_value = query_vector mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test query", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test query", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test query", documents=sample_documents_with_vectors) - # Assert: Vector scores are calculated - assert len(result) == 3 - # Verify cosine similarity was computed (doc2 vector is closest to query vector) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_score_threshold_filtering_weighted( self, @@ -742,28 +957,40 @@ class TestWeightRerankRunner: - Keyword weight (0.4) is applied to keyword scores - Combined score is the sum of weighted components """ - # Arrange: Create runner with known weights runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Python is a programming language": ["python", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 
0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test", documents=sample_documents_with_vectors) - # Assert: Scores are combined with weights - # Score = 0.6 * vector_score + 0.4 * keyword_score - assert len(result) == 3 - assert all("score" in doc.metadata for doc in result) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_existing_vector_score_in_metadata( self, @@ -778,7 +1005,6 @@ class TestWeightRerankRunner: - If document already has a score in metadata, it's used - Cosine similarity calculation is skipped for such documents """ - # Arrange: Documents with pre-existing scores documents = [ Document( page_content="Content with existing score", @@ -790,24 +1016,29 @@ class TestWeightRerankRunner: runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Content with existing score": ["test"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", documents) + vector_scores = runner._calculate_cosine("tenant123", "test", documents, weights_config.vector_setting) + expected_score = 0.6 * vector_scores[0] + 0.4 * query_scores[0] + result = runner.run(query="test", documents=documents) - # Assert: Existing score is used in calculation assert len(result) == 1 - # The final score should incorporate the existing score (0.95) with vector weight (0.6) + assert result[0].metadata["doc_id"] == "doc1" + assert result[0].metadata["score"] == pytest.approx(expected_score, rel=1e-6) class TestRerankRunnerFactory: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index ca08cb0591..b90c4935af 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,80 +1,41 @@ -""" -Unit tests for dataset retrieval functionality. - -This module provides comprehensive test coverage for the RetrievalService class, -which is responsible for retrieving relevant documents from datasets using various -search strategies. - -Core Retrieval Mechanisms Tested: -================================== -1. 
**Vector Search (Semantic Search)** - - Uses embedding vectors to find semantically similar documents - - Supports score thresholds and top-k limiting - - Can filter by document IDs and metadata - -2. **Keyword Search** - - Traditional text-based search using keyword matching - - Handles special characters and query escaping - - Supports document filtering - -3. **Full-Text Search** - - BM25-based full-text search for text matching - - Used in hybrid search scenarios - -4. **Hybrid Search** - - Combines vector and full-text search results - - Implements deduplication to avoid duplicate chunks - - Uses DataPostProcessor for score merging with configurable weights - -5. **Score Merging Algorithms** - - Deduplication based on doc_id - - Retains higher-scoring duplicates - - Supports weighted score combination - -6. **Metadata Filtering** - - Filters documents based on metadata conditions - - Supports document ID filtering - -Test Architecture: -================== -- **Fixtures**: Provide reusable mock objects (datasets, documents, Flask app) -- **Mocking Strategy**: Mock at the method level (embedding_search, keyword_search, etc.) - rather than at the class level to properly simulate the ThreadPoolExecutor behavior -- **Pattern**: All tests follow Arrange-Act-Assert (AAA) pattern -- **Isolation**: Each test is independent and doesn't rely on external state - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py -v - - # Run a specific test class - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::TestRetrievalService -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::\ -TestRetrievalService::test_vector_search_basic -v - -Notes: -====== -- The RetrievalService uses ThreadPoolExecutor for concurrent search operations -- Tests mock the individual search methods to avoid threading complexity -- All mocked search methods modify the all_documents list in-place -- Score thresholds and top-k limits are enforced by the search methods -""" - +import threading +from contextlib import contextmanager, nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest +from flask import Flask, current_app +from sqlalchemy import column +from core.app.app_config.entities import ( + Condition as AppCondition, +) +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, +) +from core.app.app_config.entities import ( + MetadataFilteringCondition as AppMetadataFilteringCondition, +) +from core.app.app_config.entities import ( + ModelConfig as AppModelConfig, +) +from core.app.app_config.entities import ModelConfig as WorkflowModelConfig +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.agent_entities import PlanningStrategy +from core.entities.model_entities import ModelStatus from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from 
dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.nodes.knowledge_retrieval import exc +from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== @@ -2013,3 +1974,3091 @@ class TestDocumentModel: assert doc1 == doc2 assert doc1 != doc3 + + +# ==================== Helper Functions ==================== + + +def create_mock_dataset_methods( + dataset_id: str | None = None, + tenant_id: str | None = None, + provider: str = "dify", + indexing_technique: str = "high_quality", + available_document_count: int = 10, +) -> Mock: + """ + Create a mock Dataset object for testing. + + Args: + dataset_id: Unique identifier for the dataset + tenant_id: Tenant ID for the dataset + provider: Provider type ("dify" or "external") + indexing_technique: Indexing technique ("high_quality" or "economy") + available_document_count: Number of available documents + + Returns: + Mock: A properly configured Dataset mock + """ + dataset = Mock(spec=Dataset) + dataset.id = dataset_id or str(uuid4()) + dataset.tenant_id = tenant_id or str(uuid4()) + dataset.name = "test_dataset" + dataset.provider = provider + dataset.indexing_technique = indexing_technique + dataset.available_document_count = available_document_count + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "top_k": 4, + "score_threshold_enabled": False, + } + return dataset + + +def create_mock_document_methods( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +# ==================== Test _check_knowledge_rate_limit ==================== + + +class TestCheckKnowledgeRateLimit: + """ + Test suite for _check_knowledge_rate_limit method. + + The _check_knowledge_rate_limit method validates whether a tenant has + exceeded their knowledge retrieval rate limit. This is important for: + - Preventing abuse of the knowledge retrieval system + - Enforcing subscription plan limits + - Tracking usage for billing purposes + + Test Cases: + ============ + 1. Rate limit disabled - no exception raised + 2. Rate limit enabled but not exceeded - no exception raised + 3. Rate limit enabled and exceeded - RateLimitExceededError raised + 4. Redis operations are performed correctly + 5. 
RateLimitLog is created when limit is exceeded + """ + + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service): + """ + Test that when rate limit is disabled, no exception is raised. + + This test verifies the behavior when the tenant's subscription + does not have rate limiting enabled. + + Verifies: + - FeatureService.get_knowledge_rate_limit is called + - No Redis operations are performed + - No exception is raised + - Retrieval proceeds normally + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit disabled + mock_limit = Mock() + mock_limit.enabled = False + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify FeatureService was called + mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id) + + # Verify no Redis operations were performed + assert not mock_redis.zadd.called + assert not mock_redis.zremrangebyscore.called + assert not mock_redis.zcard.called + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory): + """ + Test that when rate limit is enabled but not exceeded, no exception is raised. + + This test simulates a tenant making requests within their rate limit. + The Redis sorted set stores timestamps of recent requests, and old + requests (older than 60 seconds) are removed. 
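+
+        A sketch of the sliding-window bookkeeping this test simulates; the
+        key name and the 60-second window are inferred from the assertions
+        below, other names are illustrative:
+
+            now_ms = int(time.time() * 1000)
+            redis_client.zadd(f"rate_limit_{tenant_id}", {now_ms: now_ms})
+            redis_client.zremrangebyscore(f"rate_limit_{tenant_id}", 0, now_ms - 60000)
+            exceeded = redis_client.zcard(f"rate_limit_{tenant_id}") > knowledge_rate_limit.limit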
+
+        Verifies:
+        - Redis zadd is called to track the request
+        - Redis zremrangebyscore removes old entries
+        - Redis zcard returns count within limit
+        - No exception is raised
+        """
+        # Arrange
+        tenant_id = str(uuid4())
+        dataset_retrieval = DatasetRetrieval()
+
+        # Mock rate limit enabled with limit of 100 requests per minute
+        mock_limit = Mock()
+        mock_limit.enabled = True
+        mock_limit.limit = 100
+        mock_limit.subscription_plan = "professional"
+        mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
+
+        # Mock time: time.time() returns seconds; the code under test converts
+        # to milliseconds itself, so no extra arithmetic mocking is needed
+        current_time = 1234567890000  # current time in milliseconds
+        mock_time.time.return_value = current_time / 1000
+
+        # Mock Redis operations: zcard returns 50 (within limit of 100)
+        mock_redis.zcard.return_value = 50
+
+        # Mock session_factory.create_session
+        mock_session = MagicMock()
+        mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
+        mock_session_factory.create_session.return_value.__exit__.return_value = None
+
+        # Act & Assert - should not raise any exception
+        dataset_retrieval._check_knowledge_rate_limit(tenant_id)
+
+        # Verify Redis operations
+        expected_key = f"rate_limit_{tenant_id}"
+        mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time})
+        mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000)
+        mock_redis.zcard.assert_called_once_with(expected_key)
+
+    @patch("core.rag.retrieval.dataset_retrieval.session_factory")
+    @patch("core.rag.retrieval.dataset_retrieval.FeatureService")
+    @patch("core.rag.retrieval.dataset_retrieval.redis_client")
+    @patch("core.rag.retrieval.dataset_retrieval.time")
+    def test_rate_limit_enabled_exceeded_raises_exception(
+        self, mock_time, mock_redis, mock_feature_service, mock_session_factory
+    ):
+        """
+        Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised.
+
+        This test simulates a tenant exceeding their rate limit. When the count
+        of recent requests exceeds the limit, an exception should be raised and
+        a RateLimitLog should be created.
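+
+        Sketch of the log row expected below (fields taken from the
+        assertions; the actual RateLimitLog model may carry more columns):
+
+            RateLimitLog(tenant_id=tenant_id,
+                         subscription_plan="professional",
+                         operation="knowledge")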
+ + Verifies: + - Redis zcard returns count exceeding limit + - RateLimitExceededError is raised with correct message + - RateLimitLog is created in database + - Session operations are performed correctly + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 + mock_time.time.return_value = current_time / 1000 + + # Mock Redis operations - return count exceeding limit + mock_redis.zcard.return_value = 150 # Exceeds limit of 100 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert + with pytest.raises(exc.RateLimitExceededError) as exc_info: + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify exception message + assert "knowledge base request rate limit" in str(exc_info.value) + + # Verify RateLimitLog was created + mock_session.add.assert_called_once() + added_log = mock_session.add.call_args[0][0] + assert added_log.tenant_id == tenant_id + assert added_log.subscription_plan == "professional" + assert added_log.operation == "knowledge" + + +# ==================== Test _get_available_datasets ==================== + + +class TestGetAvailableDatasets: + """ + Test suite for _get_available_datasets method. + + The _get_available_datasets method retrieves datasets that are available + for retrieval. A dataset is considered available if: + - It belongs to the specified tenant + - It's in the list of requested dataset_ids + - It has at least one completed, enabled, non-archived document OR + - It's an external provider dataset + + Note: Due to SQLAlchemy subquery complexity, full testing is done in + integration tests. Unit tests here verify basic behavior. + """ + + def test_method_exists_and_has_correct_signature(self): + """ + Test that the method exists and has the correct signature. + + Verifies: + - Method exists on DatasetRetrieval class + - Accepts tenant_id and dataset_ids parameters + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + + # Assert - method exists + assert hasattr(dataset_retrieval, "_get_available_datasets") + # Assert - method is callable + assert callable(dataset_retrieval._get_available_datasets) + + +# ==================== Test knowledge_retrieval ==================== + + +class TestDatasetRetrievalKnowledgeRetrieval: + """ + Test suite for knowledge_retrieval method. + + The knowledge_retrieval method is the main entry point for retrieving + knowledge from datasets. It orchestrates the entire retrieval process: + 1. Checks rate limits + 2. Gets available datasets + 3. Applies metadata filtering if enabled + 4. Performs retrieval (single or multiple mode) + 5. Formats and returns results + + Test Cases: + ============ + 1. Single mode retrieval + 2. Multiple mode retrieval + 3. Metadata filtering disabled + 4. Metadata filtering automatic + 5. Metadata filtering manual + 6. External documents handling + 7. Dify documents handling + 8. Empty results handling + 9. Rate limit exceeded + 10. 
No available datasets + """ + + def test_knowledge_retrieval_single_mode_basic(self): + """ + Test knowledge_retrieval in single retrieval mode - basic check. + + Note: Full single mode testing requires complex model mocking and + is better suited for integration tests. This test verifies the + method accepts single mode requests. + + Verifies: + - Method can accept single mode request + - Request parameters are correctly structured + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + model_mode="chat", + completion_params={"temperature": 0.7}, + ) + + # Assert - request is properly structured + assert request.retrieval_mode == "single" + assert request.model_provider == "openai" + assert request.model_name == "gpt-4" + assert request.model_mode == "chat" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor): + """ + Test knowledge_retrieval in multiple retrieval mode. + + In multiple mode, retrieval is performed across all datasets and + results are combined and reranked. + + Verifies: + - Rate limit is checked + - Available datasets are retrieved + - Multiple retrieval is performed + - Results are combined and reranked + - Results are formatted correctly + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id1 = str(uuid4()) + dataset_id2 = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id1, dataset_id2], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + score_threshold=0.7, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets + mock_dataset1 = create_mock_dataset_methods(dataset_id=dataset_id1, tenant_id=tenant_id) + mock_dataset2 = create_mock_dataset_methods(dataset_id=dataset_id2, tenant_id=tenant_id) + with patch.object( + dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2] + ): + # Mock get_metadata_filter_condition + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return documents + doc1 = create_mock_document_methods("Python is great", "doc1", score=0.9) + doc2 = create_mock_document_methods("Python is awesome", "doc2", score=0.8) + with patch.object( + dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2] + ) as mock_multiple_retrieve: + # Mock format_retrieval_documents + mock_record = Mock() + mock_record.segment = Mock() + mock_record.segment.dataset_id = dataset_id1 + mock_record.segment.document_id = str(uuid4()) + mock_record.segment.index_node_hash = "hash123" + mock_record.segment.hit_count = 5 + mock_record.segment.word_count = 100 + mock_record.segment.position = 1 + 
mock_record.segment.get_sign_content.return_value = "Python is great" + mock_record.segment.answer = None + mock_record.score = 0.9 + mock_record.child_chunks = [] + mock_record.summary = None + mock_record.files = None + + mock_retrieval_service = Mock() + mock_retrieval_service.format_retrieval_documents.return_value = [mock_record] + + with patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService", + return_value=mock_retrieval_service, + ): + # Mock database queries + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + mock_dataset_from_db = Mock() + mock_dataset_from_db.id = dataset_id1 + mock_dataset_from_db.name = "test_dataset" + + mock_document = Mock() + mock_document.id = str(uuid4()) + mock_document.name = "test_doc" + mock_document.data_source_type = "upload_file" + mock_document.doc_metadata = {} + + mock_session.query.return_value.filter.return_value.all.return_value = [ + mock_dataset_from_db + ] + mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( + [mock_dataset_from_db, mock_document] + ) + + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + mock_multiple_retrieve.assert_called_once() + + def test_knowledge_retrieval_metadata_filtering_disabled(self): + """ + Test knowledge_retrieval with metadata filtering disabled. + + When metadata filtering is disabled, get_metadata_filter_condition is + NOT called (the method checks metadata_filtering_mode != "disabled"). + + Verifies: + - get_metadata_filter_condition is NOT called when mode is "disabled" + - Retrieval proceeds without metadata filters + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + metadata_filtering_mode="disabled", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + # Mock get_metadata_filter_condition - should NOT be called when disabled + with patch.object( + dataset_retrieval, + "get_metadata_filter_condition", + return_value=(None, None), + ) as mock_get_metadata: + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + # get_metadata_filter_condition should NOT be called when mode is "disabled" + mock_get_metadata.assert_not_called() + + def test_knowledge_retrieval_with_external_documents(self): + """ + Test knowledge_retrieval with external documents. + + External documents come from external knowledge bases and should + be formatted differently than Dify documents. 
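+
+        Sketch of the external document exercised here (fields mirror the
+        metadata built below; the real external schema may carry more fields):
+
+            Document(
+                page_content="External knowledge",
+                provider="external",
+                metadata={"dataset_id": ..., "dataset_name": "external_kb",
+                          "document_id": "ext_doc1", "title": "External Document"},
+            )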
+ + Verifies: + - External documents are handled correctly + - Provider is set to "external" + - Metadata includes external-specific fields + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id, provider="external") + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Create external document + external_doc = create_mock_document_methods( + "External knowledge", + "doc1", + score=0.9, + provider="external", + additional_metadata={ + "dataset_id": dataset_id, + "dataset_name": "external_kb", + "document_id": "ext_doc1", + "title": "External Document", + }, + ) + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + if result: + assert result[0].metadata.data_source_type == "external" + + def test_knowledge_retrieval_empty_results(self): + """ + Test knowledge_retrieval when no documents are found. + + Verifies: + - Empty list is returned + - No errors are raised + - All dependencies are still called + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return empty list + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_rate_limit_exceeded(self): + """ + Test knowledge_retrieval when rate limit is exceeded. 
+ + Verifies: + - RateLimitExceededError is raised + - No further processing occurs + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit to raise exception + with patch.object( + dataset_retrieval, + "_check_knowledge_rate_limit", + side_effect=exc.RateLimitExceededError("Rate limit exceeded"), + ): + # Act & Assert + with pytest.raises(exc.RateLimitExceededError): + dataset_retrieval.knowledge_retrieval(request) + + def test_knowledge_retrieval_no_available_datasets(self): + """ + Test knowledge_retrieval when no datasets are available. + + Verifies: + - Empty list is returned + - No retrieval is attempted + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets to return empty list + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self): + """ + Test that knowledge_retrieval processes multiple documents with different scores. + + Note: Full sorting and position testing requires complex SQLAlchemy mocking + which is better suited for integration tests. This test verifies documents + with different scores can be created and have their metadata. + + Verifies: + - Documents can be created with different scores + - Score metadata is properly set + """ + # Create documents with different scores + doc1 = create_mock_document_methods("Low score", "doc1", score=0.6) + doc2 = create_mock_document_methods("High score", "doc2", score=0.95) + doc3 = create_mock_document_methods("Medium score", "doc3", score=0.8) + + # Assert - each document has the correct score + assert doc1.metadata["score"] == 0.6 + assert doc2.metadata["score"] == 0.95 + assert doc3.metadata["score"] == 0.8 + + # Assert - documents are correctly sorted (not the retrieval result, just the list) + unsorted = [doc1, doc2, doc3] + sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True) + assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6] + + +class TestProcessMetadataFilterFunc: + """ + Comprehensive test suite for process_metadata_filter_func method. + + This test class validates all metadata filtering conditions supported by + the DatasetRetrieval class, including string operations, numeric comparisons, + null checks, and list operations. + + Method Signature: + ================== + def process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + ) -> list: + + The method builds SQLAlchemy filter expressions by: + 1. Validating value is not None (except for empty/not empty conditions) + 2. 
Using DatasetDocument.doc_metadata JSON field operations
+    3. Adding appropriate SQLAlchemy expressions to the filters list
+    4. Returning the updated filters list
+
+    Mocking Strategy:
+    ==================
+    - Mock DatasetDocument.doc_metadata to avoid database dependencies
+    - Verify filter expressions are created correctly
+    - Test with various data types (str, int, float, list)
+    """
+
+    @pytest.fixture
+    def retrieval(self):
+        """
+        Create a DatasetRetrieval instance for testing.
+
+        Returns:
+            DatasetRetrieval: Instance to test process_metadata_filter_func
+        """
+        return DatasetRetrieval()
+
+    @pytest.fixture
+    def mock_doc_metadata(self):
+        """
+        Mock the DatasetDocument.doc_metadata JSON field.
+
+        The method uses DatasetDocument.doc_metadata[metadata_name] to access
+        JSON fields. We mock this to avoid database dependencies.
+
+        Returns:
+            Mock: Mocked doc_metadata attribute
+        """
+        mock_metadata_field = MagicMock()
+
+        # Create mock for string access
+        mock_string_access = MagicMock()
+        mock_string_access.like = MagicMock()
+        mock_string_access.notlike = MagicMock()
+        mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
+        mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
+        mock_string_access.in_ = MagicMock(return_value=MagicMock())
+
+        # Create mock for float access (for numeric comparisons)
+        mock_float_access = MagicMock()
+        mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__le__ = MagicMock(return_value=MagicMock())
+        mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
+
+        # Create mock for null checks
+        mock_null_access = MagicMock()
+        mock_null_access.is_ = MagicMock(return_value=MagicMock())
+        mock_null_access.isnot = MagicMock(return_value=MagicMock())
+
+        # Setup __getitem__ to return appropriate mock based on usage
+        def getitem_side_effect(name):
+            if name in ["author", "title", "category"]:
+                return mock_string_access
+            elif name in ["year", "price", "rating"]:
+                return mock_float_access
+            else:
+                return mock_string_access
+
+        mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
+        mock_metadata_field.as_string.return_value = mock_string_access
+        mock_metadata_field.as_float.return_value = mock_float_access
+        # Expose the null-check operators on the field accessors so "empty" /
+        # "not empty" conditions can be asserted against them
+        mock_string_access.is_ = mock_null_access.is_
+        mock_string_access.isnot = mock_null_access.isnot
+        mock_float_access.is_ = mock_null_access.is_
+        mock_float_access.isnot = mock_null_access.isnot
+
+        return mock_metadata_field
+
+    # ==================== String Condition Tests ====================
+
+    def test_contains_condition_string_value(self, retrieval):
+        """
+        Test 'contains' condition with string value.
+
+        Verifies:
+        - Filters list is populated with LIKE expression
+        - Pattern matching uses %value% syntax
+        """
+        filters = []
+        sequence = 0
+        condition = "contains"
+        metadata_name = "author"
+        value = "John"
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_not_contains_condition(self, retrieval):
+        """
+        Test 'not contains' condition.
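+
+        A filter along these lines is expected (sketch only; the accessor
+        chain is assumed from the JSON-field mocks above, not read from the
+        implementation):
+
+            DatasetDocument.doc_metadata[metadata_name].as_string().notlike(f"%{value}%")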
+ + Verifies: + - Filters list is populated with NOT LIKE expression + - Pattern matching uses %value% syntax with negation + """ + filters = [] + sequence = 0 + condition = "not contains" + metadata_name = "title" + value = "banned" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_start_with_condition(self, retrieval): + """ + Test 'start with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses value% syntax + """ + filters = [] + sequence = 0 + condition = "start with" + metadata_name = "category" + value = "tech" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_end_with_condition(self, retrieval): + """ + Test 'end with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value syntax + """ + filters = [] + sequence = 0 + condition = "end with" + metadata_name = "filename" + value = ".pdf" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Equality Condition Tests ==================== + + def test_is_condition_with_string_value(self, retrieval): + """ + Test 'is' (=) condition with string value. + + Verifies: + - Filters list is populated with equality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = "Jane Doe" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_equals_condition_with_string_value(self, retrieval): + """ + Test '=' condition with string value. + + Verifies: + - Same behavior as 'is' condition + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "=" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_int_value(self, retrieval): + """ + Test 'is' condition with integer value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_float_value(self, retrieval): + """ + Test 'is' condition with float value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "price" + value = 19.99 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_string_value(self, retrieval): + """ + Test 'is not' (≠) condition with string value. 
+ + Verifies: + - Filters list is populated with inequality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "author" + value = "Unknown" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_equals_condition(self, retrieval): + """ + Test '≠' condition with string value. + + Verifies: + - Same behavior as 'is not' condition + - Inequality expression is used + """ + filters = [] + sequence = 0 + condition = "≠" + metadata_name = "category" + value = "archived" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_numeric_value(self, retrieval): + """ + Test 'is not' condition with numeric value. + + Verifies: + - Numeric inequality comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Null Condition Tests ==================== + + def test_empty_condition(self, retrieval): + """ + Test 'empty' condition (null check). + + Verifies: + - Filters list is populated with IS NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "empty" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_empty_condition(self, retrieval): + """ + Test 'not empty' condition (not null check). + + Verifies: + - Filters list is populated with IS NOT NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "not empty" + metadata_name = "description" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Numeric Comparison Tests ==================== + + def test_before_condition(self, retrieval): + """ + Test 'before' (<) condition. + + Verifies: + - Filters list is populated with less than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "before" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_condition(self, retrieval): + """ + Test '<' condition. + + Verifies: + - Same behavior as 'before' condition + - Less than expression is used + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "price" + value = 100.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_after_condition(self, retrieval): + """ + Test 'after' (>) condition. 
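+
+        Expected shape of the comparison (sketch; the as_float() accessor is
+        assumed from this class's mocks and docstrings):
+
+            DatasetDocument.doc_metadata[metadata_name].as_float() > value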
+ + Verifies: + - Filters list is populated with greater than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "after" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_condition(self, retrieval): + """ + Test '>' condition. + + Verifies: + - Same behavior as 'after' condition + - Greater than expression is used + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≤' condition. + + Verifies: + - Filters list is populated with less than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≤" + metadata_name = "price" + value = 50.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_ascii(self, retrieval): + """ + Test '<=' condition. + + Verifies: + - Same behavior as '≤' condition + - Less than or equal expression is used + """ + filters = [] + sequence = 0 + condition = "<=" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≥' condition. + + Verifies: + - Filters list is populated with greater than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≥" + metadata_name = "rating" + value = 3.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_ascii(self, retrieval): + """ + Test '>=' condition. + + Verifies: + - Same behavior as '≥' condition + - Greater than or equal expression is used + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== List/In Condition Tests ==================== + + def test_in_condition_with_comma_separated_string(self, retrieval): + """ + Test 'in' condition with comma-separated string value. + + Verifies: + - String is split into list + - Whitespace is trimmed from each value + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "tech, science, AI " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_list_value(self, retrieval): + """ + Test 'in' condition with list value. 
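+
+        Expected normalization (sketch inferred from the docstrings in this
+        class, not from the implementation itself):
+
+            # list/tuple input: drop None entries
+            values = [v for v in value if v is not None]
+            # string input: split on commas and strip whitespace
+            values = [v.strip() for v in value.split(",") if v.strip()]
+            filters.append(DatasetDocument.doc_metadata[metadata_name].as_string().in_(values))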
+ + Verifies: + - List is processed correctly + - None values are filtered out + - IN expression is created with valid values + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "tags" + value = ["python", "javascript", None, "golang"] + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_tuple_value(self, retrieval): + """ + Test 'in' condition with tuple value. + + Verifies: + - Tuple is processed like a list + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = ("tech", "science", "ai") + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_empty_string(self, retrieval): + """ + Test 'in' condition with empty string value. + + Verifies: + - Empty string results in literal(False) filter + - No valid values to match + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + # Verify it's a literal(False) expression + # This is a bit tricky to test without access to the actual expression + + def test_in_condition_with_only_whitespace(self, retrieval): + """ + Test 'in' condition with whitespace-only string value. + + Verifies: + - Whitespace-only string results in literal(False) filter + - All values are stripped and filtered out + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = " , , " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_single_string(self, retrieval): + """ + Test 'in' condition with single non-comma string. + + Verifies: + - Single string is treated as single-item list + - IN expression is created with one value + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Edge Case Tests ==================== + + def test_none_value_with_non_empty_condition(self, retrieval): + """ + Test None value with conditions that require value. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values (except empty/not empty) + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 # No filter added + + def test_none_value_with_equals_condition(self, retrieval): + """ + Test None value with 'is' (=) condition. 
+ + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_none_value_with_numeric_condition(self, retrieval): + """ + Test None value with numeric comparison condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "year" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_existing_filters_preserved(self, retrieval): + """ + Test that existing filters are preserved. + + Verifies: + - Existing filters in the list are not removed + - New filters are appended to the list + """ + existing_filter = MagicMock() + filters = [existing_filter] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 2 + assert filters[0] == existing_filter + + def test_multiple_filters_accumulated(self, retrieval): + """ + Test multiple calls to accumulate filters. + + Verifies: + - Each call adds a new filter to the list + - All filters are preserved across calls + """ + filters = [] + + # First filter + retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) + assert len(filters) == 1 + + # Second filter + retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) + assert len(filters) == 2 + + # Third filter + retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) + assert len(filters) == 3 + + def test_unknown_condition(self, retrieval): + """ + Test unknown/unsupported condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for unknown conditions + """ + filters = [] + sequence = 0 + condition = "unknown_condition" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_empty_string_value_with_contains(self, retrieval): + """ + Test empty string value with 'contains' condition. + + Verifies: + - Filter is added even with empty string + - LIKE expression is created + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_special_characters_in_value(self, retrieval): + """ + Test special characters in value string. 
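+    # NOTE: the tests in this class deliberately assert only that a filter
+    # expression was appended. Inspecting the generated SQLAlchemy expression
+    # itself would require compiling it against a dialect (see the
+    # literal(False) comment in test_in_condition_with_empty_string), so each
+    # branch is pinned down by the condition/value combination instead.
+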
+    def test_special_characters_in_value(self, retrieval):
+        """
+        Test special characters in value string.
+
+        Verifies:
+        - Special characters are handled in value
+        - LIKE expression is created correctly
+        """
+        filters = []
+        sequence = 0
+        condition = "contains"
+        metadata_name = "title"
+        value = "C++ & Python's features"
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_zero_value_with_numeric_condition(self, retrieval):
+        """
+        Test zero value with numeric comparison condition.
+
+        Verifies:
+        - Zero is treated as valid value
+        - Numeric comparison is performed
+        """
+        filters = []
+        sequence = 0
+        condition = ">"
+        metadata_name = "price"
+        value = 0
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_negative_value_with_numeric_condition(self, retrieval):
+        """
+        Test negative value with numeric comparison condition.
+
+        Verifies:
+        - Negative numbers are handled correctly
+        - Numeric comparison is performed
+        """
+        filters = []
+        sequence = 0
+        condition = "<"
+        metadata_name = "temperature"
+        value = -10.5
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+    def test_float_value_with_integer_comparison(self, retrieval):
+        """
+        Test float value with numeric comparison condition.
+
+        Verifies:
+        - Float values work correctly
+        - Numeric comparison is performed
+        """
+        filters = []
+        sequence = 0
+        condition = ">="
+        metadata_name = "rating"
+        value = 4.5
+
+        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
+
+        assert result == filters
+        assert len(filters) == 1
+
+
+class TestKnowledgeRetrievalRegression:
+    @pytest.fixture
+    def mock_dataset(self) -> Dataset:
+        dataset = Mock(spec=Dataset)
+        dataset.id = str(uuid4())
+        dataset.tenant_id = str(uuid4())
+        dataset.name = "test_dataset"
+        dataset.indexing_technique = "high_quality"
+        dataset.provider = "dify"
+        return dataset
+
+    def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
+        """
+        Regression test: reranking used to run after `with flask_app.app_context():`
+        had already exited, so DataPostProcessor raised a RuntimeError inside the worker thread.
+        `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
+        so we must assert from that list (not from an outer try/except).
+        """
+        dataset_retrieval = DatasetRetrieval()
+        flask_app = Flask(__name__)
+        tenant_id = str(uuid4())
+
+        # second dataset to ensure dataset_count > 1 reranking branch
+        secondary_dataset = Mock(spec=Dataset)
+        secondary_dataset.id = str(uuid4())
+        secondary_dataset.provider = "dify"
+        secondary_dataset.indexing_technique = "high_quality"
+
+        # retriever returns 1 doc into internal list (all_documents_item)
+        document = Document(
+            page_content="Context aware doc",
+            metadata={
+                "doc_id": "doc1",
+                "score": 0.95,
+                "document_id": str(uuid4()),
+                "dataset_id": mock_dataset.id,
+            },
+            provider="dify",
+        )
+
+        def fake_retriever(
+            flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
+        ):
+            all_documents.append(document)
+
+        called = {"init": 0, "invoke": 0}
+
+        class ContextRequiredPostProcessor:
+            def __init__(self, *args, **kwargs):
+                called["init"] += 1
+                # will raise RuntimeError if no Flask app context exists
+                _ = current_app.name
+
+            def invoke(self, *args, **kwargs):
+                called["invoke"] += 1
+                _ = current_app.name
+                return kwargs.get("documents") or args[1]
+
+        # output list from _multiple_retrieve_thread
+        all_documents: list[Document] = []
+
+        # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
+        thread_exceptions: list[Exception] = []
+
+        def target():
+            with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
+                with patch(
+                    "core.rag.retrieval.dataset_retrieval.DataPostProcessor",
+                    ContextRequiredPostProcessor,
+                ):
+                    dataset_retrieval._multiple_retrieve_thread(
+                        flask_app=flask_app,
+                        available_datasets=[mock_dataset, secondary_dataset],
+                        metadata_condition=None,
+                        metadata_filter_document_ids=None,
+                        all_documents=all_documents,
+                        tenant_id=tenant_id,
+                        reranking_enable=True,
+                        reranking_mode="reranking_model",
+                        reranking_model={
+                            "reranking_provider_name": "cohere",
+                            "reranking_model_name": "rerank-v2",
+                        },
+                        weights=None,
+                        top_k=3,
+                        score_threshold=0.0,
+                        query="test query",
+                        attachment_id=None,
+                        dataset_count=2,  # force reranking branch
+                        thread_exceptions=thread_exceptions,  # collects failures from the worker thread
+                    )
+
+        t = threading.Thread(target=target)
+        t.start()
+        t.join()
+
+        # Ensure reranking branch was actually executed
+        assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."
+
+        # With the fix in place, reranking runs inside the app context, so the
+        # worker thread should not have recorded any exception.
+        assert not thread_exceptions, thread_exceptions
+
+
+class _FakeFlaskApp:
+    def app_context(self):
+        return nullcontext()
+
+
+class TestDatasetRetrievalAdditionalHelpers:
+    @pytest.fixture
+    def retrieval(self) -> DatasetRetrieval:
+        return DatasetRetrieval()
+
+    def test_llm_usage_and_record_usage(self, retrieval: DatasetRetrieval) -> None:
+        empty_usage = retrieval.llm_usage
+        assert empty_usage.total_tokens == 0
+
+        retrieval._record_usage(None)
+        assert retrieval.llm_usage.total_tokens == 0
+
+        usage_1 = LLMUsage.from_metadata({"prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5})
+        usage_2 = LLMUsage.from_metadata({"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5})
+        retrieval._record_usage(usage_1)
+        retrieval._record_usage(usage_2)
+        assert retrieval.llm_usage.total_tokens == 10
+
+    def test_replace_metadata_filter_value(self, retrieval: DatasetRetrieval) -> None:
+        assert retrieval._replace_metadata_filter_value("plain", {}) == "plain"
+        replaced = retrieval._replace_metadata_filter_value(
+            "hello {{name}}\n\t{{missing}}",
+            {"name": "world"},
+        )
+        assert replaced == "hello world {{missing}}"
+
+    def test_process_metadata_filter_in_with_scalar_fallback(self) -> None:
+        filters: list = []
+        result = DatasetRetrieval.process_metadata_filter_func(
+            sequence=0,
+            condition="in",
+            metadata_name="category",
+            value=123,
+            filters=filters,
+        )
+        assert result is filters
+        assert len(filters) == 1
+
+    def test_calculate_vector_score(self, retrieval: DatasetRetrieval) -> None:
+        doc_high = Document(page_content="a", metadata={"score": 0.9}, provider="dify")
+        doc_low = Document(page_content="b", metadata={"score": 0.2}, provider="dify")
+        doc_no_meta = Document(page_content="c", metadata={}, provider="dify")
+
+        filtered = retrieval.calculate_vector_score([doc_low, doc_high, doc_no_meta], top_k=1, score_threshold=0.5)
+        assert len(filtered) == 1
+        assert filtered[0].metadata["score"] == 0.9
+
+        assert retrieval.calculate_vector_score([doc_low], top_k=2, score_threshold=1.0) == []
+
+    def test_calculate_keyword_score(self, retrieval: DatasetRetrieval) -> None:
+        documents = [
+            Document(page_content="python language", metadata={"doc_id": "1"}, provider="dify"),
+            Document(page_content="java language", metadata={"doc_id": "2"}, provider="dify"),
+        ]
+        keyword_handler = Mock()
+        keyword_handler.extract_keywords.side_effect = [
+            ["python", "language"],
+            ["python", "language"],
+            ["java", "language"],
+        ]
+
+        with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler):
+            ranked = retrieval.calculate_keyword_score("python language", documents, top_k=1)
+
+        assert len(ranked) == 1
+        assert "keywords" in ranked[0].metadata
+        assert ranked[0].metadata["doc_id"] == "1"
+
+    def test_send_trace_task(self, retrieval: DatasetRetrieval) -> None:
+        trace_manager = Mock()
+        retrieval.application_generate_entity = SimpleNamespace(trace_manager=trace_manager)
+        docs = [Document(page_content="d", metadata={}, provider="dify")]
+
+        retrieval._send_trace_task("m1", docs, {"cost": 1})
+
trace_manager.add_trace_task.assert_called_once() + + retrieval.application_generate_entity = None + trace_manager.reset_mock() + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_not_called() + + def test_on_query(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query=None, + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="web", + user_id="u1", + ) + mock_session.add_all.assert_not_called() + + retrieval._on_query( + query="python", + attachment_ids=["f1"], + dataset_ids=["d1", "d2"], + app_id="a1", + user_from="web", + user_id="u1", + ) + mock_session.add_all.assert_called() + mock_session.commit.assert_called() + + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: + usage = LLMUsage.empty_usage() + chunk_1 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace(message=SimpleNamespace(content="hello "), usage=usage), + ) + chunk_2 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace( + message=SimpleNamespace(content=[SimpleNamespace(data="world")]), + usage=None, + ), + ) + text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2])) + assert text == "hello world" + assert returned_usage == usage + + text_empty, usage_empty = retrieval._handle_invoke_result(iter([])) + assert text_empty == "" + assert usage_empty == LLMUsage.empty_usage() + + def test_get_prompt_template(self, retrieval: DatasetRetrieval) -> None: + model_config_chat = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=["x"], + ) + model_config_completion = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="completion", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + + with patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform") as mock_prompt_transform: + mock_prompt_transform.return_value.get_prompt.return_value = ["prompt"] + prompt_messages, stop = retrieval._get_prompt_template( + model_config=model_config_chat, + mode="chat", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages == ["prompt"] + assert stop == ["x"] + + with patch( + "core.rag.retrieval.dataset_retrieval.METADATA_FILTER_COMPLETION_PROMPT", + "{input_text} {metadata_fields}", + ): + prompt_messages_completion, stop_completion = retrieval._get_prompt_template( + model_config=model_config_completion, + mode="completion", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages_completion == ["prompt"] + assert stop_completion == [] + + with pytest.raises(ValueError): + retrieval._get_prompt_template( + model_config=model_config_chat, + mode="unknown-mode", + metadata_fields=[], + query="python", + ) + + def test_fetch_model_config_validation_and_success(self, retrieval: DatasetRetrieval) -> None: + with pytest.raises(ValueError, match="single_retrieval_config is required"): + retrieval._fetch_model_config("tenant-1", None) # type: ignore[arg-type] + + model_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat", completion_params={"stop": ["END"]}) + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.provider_model_bundle = Mock() + 
model_instance.model_type_instance = Mock() + model_instance.model_type_instance.get_model_schema.return_value = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance + mock_cfg_entity.return_value = SimpleNamespace( + provider="openai", + model="gpt", + stop=["END"], + parameters={"temperature": 0.1}, + ) + + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model = SimpleNamespace(status=ModelStatus.NO_CONFIGURE) + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = provider_model + with pytest.raises(ValueError, match="credentials is not initialized"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.NO_PERMISSION + with pytest.raises(ValueError, match="currently not support"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.QUOTA_EXCEEDED + with pytest.raises(ValueError, match="quota exceeded"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.ACTIVE + bad_mode_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat") + bad_mode_cfg.mode = None # type: ignore[assignment] + with pytest.raises(ValueError, match="LLM mode is required"): + retrieval._fetch_model_config("tenant-1", bad_mode_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = Mock() + model_cfg_success = AppModelConfig( + provider="openai", + name="gpt", + mode="chat", + completion_params={"temperature": 0.1, "stop": ["END"]}, + ) + _, config = retrieval._fetch_model_config("tenant-1", model_cfg_success) + assert config.provider == "openai" + assert config.model == "gpt" + assert config.stop == ["END"] + assert "stop" not in config.parameters + + def test_automatic_metadata_filter_func(self, retrieval: DatasetRetrieval) -> None: + metadata_field = SimpleNamespace(name="author") + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([Mock()]) + model_config = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + session_scalars = Mock() + session_scalars.all.return_value = [metadata_field] + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, model_config)), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, "_handle_invoke_result", return_value=('{"metadata_map":[]}', usage)), + patch("core.rag.retrieval.dataset_retrieval.parse_and_check_json_markdown") as mock_parse, + patch.object(retrieval, "_record_usage") as mock_record_usage, + ): + mock_parse.return_value = { + "metadata_map": [ + { + "metadata_field_name": 
"author", + "metadata_field_value": "Alice", + "comparison_operator": "contains", + }, + { + "metadata_field_name": "ignored", + "metadata_field_value": "value", + "comparison_operator": "contains", + }, + ] + } + result = retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + assert result == [{"metadata_name": "author", "value": "Alice", "condition": "contains"}] + mock_record_usage.assert_called_once_with(usage) + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", side_effect=RuntimeError("boom")), + ): + with pytest.raises(RuntimeError, match="boom"): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: + db_query = Mock() + db_query.where.return_value = db_query + db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="disabled", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + assert mapping is None + assert condition is None + + automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), + patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), + ): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="automatic", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=AppMetadataFilteringCondition(logical_operator="or", conditions=[]), + inputs={}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.logical_operator == "or" + + manual_conditions = AppMetadataFilteringCondition( + logical_operator="and", + conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="manual", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=manual_conditions, + inputs={"name": "Alice"}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.conditions[0].value == "Alice" + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with pytest.raises(ValueError, match="Invalid metadata filtering mode"): + retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + 
user_id="u1", + metadata_filtering_mode="unsupported", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + + def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: + session = Mock() + subquery_query = Mock() + subquery_query.where.return_value = subquery_query + subquery_query.group_by.return_value = subquery_query + subquery_query.having.return_value = subquery_query + subquery_query.subquery.return_value = SimpleNamespace( + c=SimpleNamespace( + dataset_id=column("dataset_id"), available_document_count=column("available_document_count") + ) + ) + + dataset_query = Mock() + dataset_query.outerjoin.return_value = dataset_query + dataset_query.where.return_value = dataset_query + dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.query.side_effect = [subquery_query, dataset_query] + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx): + available = retrieval._get_available_datasets("tenant-1", ["d1", "d2"]) + + assert [dataset.id for dataset in available] == ["d1", "d2"] + + def test_check_knowledge_rate_limit(self, retrieval: DatasetRetrieval) -> None: + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=2, subscription_plan="pro") + mock_redis.zcard.return_value = 1 + retrieval._check_knowledge_rate_limit("tenant-1") + mock_redis.zadd.assert_called_once() + + session = Mock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=1, subscription_plan="pro") + mock_redis.zcard.return_value = 2 + with pytest.raises(exc.RateLimitExceededError): + retrieval._check_knowledge_rate_limit("tenant-1") + session.add.assert_called_once() + + with patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit: + mock_limit.return_value = SimpleNamespace(enabled=False) + retrieval._check_knowledge_rate_limit("tenant-1") + + +def _doc( + provider: str = "dify", + content: str = "content", + score: float = 0.9, + dataset_id: str = "dataset-1", + document_id: str = "document-1", + doc_id: str = "node-1", + extra: dict | None = None, +) -> Document: + metadata = { + "score": score, + "dataset_id": dataset_id, + "document_id": document_id, + "doc_id": doc_id, + } + if extra: + metadata.update(extra) + return Document(page_content=content, metadata=metadata, provider=provider) + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + 
self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class _JoinDrivenThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._started = False + self._alive = False + + def start(self) -> None: + self._started = True + self._alive = True + + def join(self, timeout=None) -> None: + if self._started and self._alive and self._target: + self._target(**self._kwargs) + self._alive = False + + def is_alive(self) -> bool: + return self._alive + + +@contextmanager +def _timer(): + yield {"cost": 1} + + +class TestKnowledgeRetrievalCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_returns_empty_when_query_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query=None, + retrieval_mode="multiple", + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + assert retrieval.knowledge_retrieval(request) == [] + + def test_raises_when_metadata_model_config_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query="query", + retrieval_mode="multiple", + metadata_filtering_mode="automatic", + metadata_model_config=None, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + with pytest.raises(ValueError, match="metadata_model_config is required"): + retrieval.knowledge_retrieval(request) + + @pytest.mark.parametrize( + ("status", "error_cls"), + [ + (ModelStatus.NO_CONFIGURE, "ModelCredentialsNotInitializedError"), + (ModelStatus.NO_PERMISSION, "ModelNotSupportedError"), + (ModelStatus.QUOTA_EXCEEDED, "ModelQuotaExceededError"), + ], + ) + def test_single_mode_raises_for_model_status( + self, + retrieval: DatasetRetrieval, + status: ModelStatus, + error_cls: str, + ) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["dataset-1"], + query="python", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + ) + provider_model_bundle = Mock() + provider_model_bundle.configuration.get_provider_model.return_value = SimpleNamespace(status=status) + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = Mock() + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + credentials={}, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + ): + mock_model_manager.return_value.get_model_instance.return_value = model_instance + with pytest.raises(Exception) as exc_info: + retrieval.knowledge_retrieval(request) + assert error_cls in type(exc_info.value).__name__ + + +class TestRetrieveCoverage: + @pytest.fixture + def 
retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def _build_model_config(self, features: list[ModelFeature] | None = None): + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = SimpleNamespace(features=features or []) + provider_bundle = SimpleNamespace(model_type_instance=model_type_instance) + return ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt-4", + model_schema=Mock(), + mode="chat", + provider_model_bundle=provider_bundle, + credentials={}, + parameters={}, + stop=[], + ) + + def test_returns_none_when_dataset_ids_empty(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=[], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=self._build_model_config(), + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_returns_none_when_model_schema_missing(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + model_config = self._build_model_config() + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = Mock() + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config() + external_doc = _doc( + provider="external", + content="external content", + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External", "dataset_name": "External DS"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[external_doc]), + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert context == "external content" + assert files == [] + + def test_multiple_strategy_with_vision_and_source_details(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + 
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=4, + score_threshold=0.1, + rerank_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + reranking_enabled=True, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config(features=[ModelFeature.TOOL_CALL]) + external_doc = _doc( + provider="external", + content="external body", + score=0.8, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External Title", "dataset_name": "External DS"}, + ) + dify_doc = _doc( + provider="dify", + content="dify body", + score=0.9, + dataset_id="d1", + document_id="doc-1", + doc_id="node-1", + ) + record = SimpleNamespace( + segment=SimpleNamespace( + id="segment-1", + dataset_id="d1", + document_id="doc-1", + tenant_id="tenant-1", + hit_count=3, + word_count=11, + position=1, + index_node_hash="hash-1", + content="segment content", + answer="segment answer", + get_sign_content=lambda: "segment content", + ), + score=0.9, + summary="short summary", + files=None, + ) + dataset_item = SimpleNamespace(id="d1", name="Dataset One") + document_item = SimpleNamespace( + id="doc-1", + name="Document One", + data_source_type="upload_file", + doc_metadata={"lang": "en"}, + ) + upload_file = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="https://example.com/img.png", + size=123, + key="k1", + ) + execute_attachments = SimpleNamespace(all=lambda: [(SimpleNamespace(), upload_file)]) + execute_docs = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [document_item])) + execute_datasets = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [dataset_item])) + hit_callback = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.format_retrieval_documents", + return_value=[record], + ), + patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), + patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.DEBUGGER, + show_retrieve_source=True, + hit_callback=hit_callback, + message_id="m1", + vision_enabled=True, + ) + + assert "short summary" in (context or "") + assert "question:segment content answer:segment answer" in (context or "") + assert len(files or []) == 1 + hit_callback.return_retriever_resource_info.assert_called_once() + + +class TestSingleAndMultipleRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="External DS", + 
description=None, + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end") as mock_end, + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + mock_external.return_value = [ + {"content": "ext result", "metadata": {"k": "v"}, "score": 0.9, "title": "Ext Doc"} + ] + result = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + message_id="m1", + ) + + assert len(result) == 1 + assert result[0].provider == "external" + mock_end.assert_called_once() + assert retrieval.llm_usage.total_tokens == 2 + + def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="dataset desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + "top_k": 3, + "score_threshold_enabled": True, + "score_threshold": 0.2, + }, + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}) + result_doc = _doc(provider="dify", score=0.7, dataset_id="ds-1", document_id="doc-1", doc_id="node-1") + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.FunctionCallMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[result_doc] + ) as mock_retrieve, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end"), + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.ROUTER, + metadata_filter_document_ids={"ds-1": ["doc-1"]}, + metadata_condition=SimpleNamespace(), + ) + + assert results == [result_doc] + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc-1"] + assert retrieval.llm_usage.total_tokens == 1 + + def test_single_retrieve_returns_empty_when_no_dataset_selected(self, retrieval: DatasetRetrieval) -> None: + with 
patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls: + mock_router_cls.return_value.invoke.return_value = (None, LLMUsage.empty_usage()) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[ + SimpleNamespace(id="ds-1", name="DS", description=None), + ], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + ) + assert results == [] + + def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={"top_k": 2, "search_method": "semantic_search", "reranking_enable": False}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", LLMUsage.empty_usage()) + no_filter = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids=None, + metadata_condition=SimpleNamespace(), + ) + missing_doc_ids = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids={"other-ds": ["x"]}, + metadata_condition=None, + ) + assert no_filter == [] + assert missing_doc_ids == [] + + def test_multiple_retrieve_validation_paths(self, retrieval: DatasetRetrieval) -> None: + assert ( + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=[], + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + == [] + ) + + mixed = [ + SimpleNamespace(id="d1", indexing_technique="high_quality"), + SimpleNamespace(id="d2", indexing_technique="economy"), + ] + with pytest.raises(ValueError, match="different indexing technique"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=mixed, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="weighted_score", + reranking_enable=False, + ) + + high_quality_mismatch = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-b", + embedding_model_provider="provider-b", + ), + ] + with pytest.raises(ValueError, match="different embedding model"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=high_quality_mismatch, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + ) + + def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None: + 
datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + ] + doc_a = _doc(provider="dify", score=0.8, dataset_id="d1", document_id="doc-1", doc_id="dup") + doc_b = _doc(provider="dify", score=0.7, dataset_id="d2", document_id="doc-2", doc_id="dup") + doc_external = _doc( + provider="external", + score=0.9, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"dataset_name": "Ext", "title": "Ext"}, + ) + app = Flask(__name__) + weights = {"vector_setting": {}} + + def fake_multiple_thread(**kwargs): + if kwargs["query"]: + kwargs["all_documents"].extend([doc_a, doc_b]) + if kwargs["attachment_id"]: + kwargs["all_documents"].append(doc_external) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=fake_multiple_thread), + patch.object(retrieval, "_on_query") as mock_on_query, + patch.object(retrieval, "_on_retrieval_end") as mock_end, + ): + result = retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + weights=weights, + attachment_ids=["att-1"], + message_id="m1", + ) + + assert len(result) == 2 + assert any(doc.provider == "external" for doc in result) + assert weights["vector_setting"]["embedding_provider_name"] == "provider-a" + assert weights["vector_setting"]["embedding_model_name"] == "model-a" + mock_on_query.assert_called_once() + mock_end.assert_called_once() + + def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ) + ] + app = Flask(__name__) + + def failing_thread(**kwargs): + kwargs["thread_exceptions"].append(RuntimeError("thread boom")) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=failing_thread), + ): + with pytest.raises(RuntimeError, match="thread boom"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + + +class TestInternalHooksCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_on_retrieval_end_without_dify_documents(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + with patch.object(retrieval, "_send_trace_task") as mock_trace: + retrieval._on_retrieval_end( + flask_app=app, + documents=[_doc(provider="external")], + message_id="m1", + timer={"cost": 1}, + ) + mock_trace.assert_called_once() + + def test_on_retrieval_end_dify_without_document_ids(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + doc = 
Document(page_content="x", metadata={"doc_id": "n1"}, provider="dify") + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=[doc], message_id="m1", timer={"cost": 1}) + mock_trace.assert_called_once() + + def test_on_retrieval_end_updates_segments_for_text_and_image(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + docs = [ + _doc(provider="dify", document_id="doc-a", doc_id="idx-a", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-b", doc_id="att-b", extra={"doc_type": DocType.IMAGE}), + _doc(provider="dify", document_id="doc-c", doc_id="idx-c", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-d", doc_id="att-d", extra={"doc_type": DocType.IMAGE}), + ] + dataset_docs = [ + SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-c", doc_form="qa_model"), + SimpleNamespace(id="doc-d", doc_form="qa_model"), + ] + child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] + segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] + bindings = [SimpleNamespace(segment_id="seg-b"), SimpleNamespace(segment_id="seg-d")] + + def _scalars(items): + result = Mock() + result.all.return_value = items + return result + + session = Mock() + session.scalars.side_effect = [ + _scalars(dataset_docs), + _scalars(child_chunks), + _scalars(segments), + _scalars(bindings), + ] + query = Mock() + query.where.return_value = query + session.query.return_value = query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) + + query.update.assert_called_once() + session.commit.assert_called_once() + mock_trace.assert_called_once() + + def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: + flask_app = SimpleNamespace(app_context=lambda: nullcontext()) + all_documents: list[Document] = [] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=None): + assert ( + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="d1", + query="python", + top_k=1, + all_documents=all_documents, + ) + == [] + ) + + external_dataset = SimpleNamespace( + id="ext-ds", + name="External", + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=external_dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + ): + mock_external.return_value = [{"content": "e", "metadata": {}, "score": 0.8, "title": "Ext"}] + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="ext-ds", + query="python", + top_k=1, + all_documents=all_documents, + ) + + economy_dataset = SimpleNamespace( + id="eco-ds", + provider="dify", + retrieval_model={"top_k": 1}, 
+            indexing_technique="economy",
+        )
+        high_dataset = SimpleNamespace(
+            id="hq-ds",
+            provider="dify",
+            retrieval_model={
+                "search_method": "semantic_search",
+                "top_k": 4,
+                "score_threshold": 0.3,
+                "score_threshold_enabled": True,
+                "reranking_enable": True,
+                "reranking_model": {"reranking_provider_name": "x", "reranking_model_name": "y"},
+                "reranking_mode": "reranking_model",
+                "weights": {"vector_setting": {}},
+            },
+            indexing_technique="high_quality",
+        )
+        with (
+            patch(
+                "core.rag.retrieval.dataset_retrieval.db.session.scalar", side_effect=[economy_dataset, high_dataset]
+            ),
+            patch(
+                "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[_doc(provider="dify")]
+            ) as mock_retrieve,
+        ):
+            retrieval._retriever(
+                flask_app=flask_app,  # type: ignore[arg-type]
+                dataset_id="eco-ds",
+                query="python",
+                top_k=2,
+                all_documents=all_documents,
+            )
+            retrieval._retriever(
+                flask_app=flask_app,  # type: ignore[arg-type]
+                dataset_id="hq-ds",
+                query="python",
+                top_k=2,
+                all_documents=all_documents,
+                attachment_ids=["att-1"],
+            )
+        assert mock_retrieve.call_count == 2
+        assert len(all_documents) >= 3
+
+    def test_to_dataset_retriever_tool_paths(self, retrieval: DatasetRetrieval) -> None:
+        dataset_skip_zero = SimpleNamespace(id="d1", provider="dify", available_document_count=0)
+        dataset_ok_single = SimpleNamespace(
+            id="d2",
+            provider="dify",
+            available_document_count=2,
+            retrieval_model={"top_k": 2, "score_threshold_enabled": True, "score_threshold": 0.1},
+        )
+        single_config = DatasetRetrieveConfigEntity(
+            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE,
+            metadata_filtering_mode="disabled",
+        )
+        with (
+            patch(
+                "core.rag.retrieval.dataset_retrieval.db.session.scalar",
+                side_effect=[None, dataset_skip_zero, dataset_ok_single],
+            ),
+            patch(
+                "core.tools.utils.dataset_retriever.dataset_retriever_tool.DatasetRetrieverTool.from_dataset",
+                return_value="single-tool",
+            ) as mock_single_tool,
+        ):
+            single_tools = retrieval.to_dataset_retriever_tool(
+                tenant_id="tenant-1",
+                dataset_ids=["missing", "d1", "d2"],
+                retrieve_config=single_config,
+                return_resource=True,
+                invoke_from=InvokeFrom.WEB_APP,
+                hit_callback=Mock(),
+                user_id="user-1",
+                inputs={"k": "v"},
+            )
+
+        assert single_tools == ["single-tool"]
+        mock_single_tool.assert_called_once()
+
+        multiple_config_missing = DatasetRetrieveConfigEntity(
+            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
+            metadata_filtering_mode="disabled",
+            reranking_model=None,
+        )
+        with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single):
+            with pytest.raises(ValueError, match="Reranking model is required"):
+                retrieval.to_dataset_retriever_tool(
+                    tenant_id="tenant-1",
+                    dataset_ids=["d2"],
+                    retrieve_config=multiple_config_missing,
+                    return_resource=True,
+                    invoke_from=InvokeFrom.WEB_APP,
+                    hit_callback=Mock(),
+                    user_id="user-1",
+                    inputs={},
+                )
+
+        multiple_config = DatasetRetrieveConfigEntity(
+            retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
+            metadata_filtering_mode="disabled",
+            top_k=3,
+            score_threshold=0.2,
+            reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"},
+        )
+        with (
+            patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single),
+            patch(
+                "core.tools.utils.dataset_retriever.dataset_multi_retriever_tool.DatasetMultiRetrieverTool.from_dataset",
+                return_value="multi-tool",
+            ) as mock_multi_tool,
+        ):
+            multi_tools = retrieval.to_dataset_retriever_tool(
+                tenant_id="tenant-1",
+                dataset_ids=["d2"],
+                retrieve_config=multiple_config,
+                return_resource=False,
+                invoke_from=InvokeFrom.DEBUGGER,
+                hit_callback=Mock(),
+                user_id="user-1",
+                inputs={},
+            )
+        assert multi_tools == ["multi-tool"]
+        mock_multi_tool.assert_called_once()
+
+    def test_additional_small_branches(self, retrieval: DatasetRetrieval) -> None:
+        keyword_handler = Mock()
+        keyword_handler.extract_keywords.side_effect = [[], []]
+        doc = Document(page_content="doc", metadata={"doc_id": "1"}, provider="dify")
+        with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler):
+            ranked = retrieval.calculate_keyword_score("query", [doc], top_k=1)
+        assert len(ranked) == 1
+        assert ranked[0].metadata.get("score") == 0.0
+
+        with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars") as mock_scalars:
+            mock_scalars.return_value.all.return_value = []
+            with pytest.raises(ValueError):
+                retrieval._automatic_metadata_filter_func(
+                    dataset_ids=["d1"],
+                    query="python",
+                    tenant_id="tenant-1",
+                    user_id="user-1",
+                    metadata_model_config=None,  # type: ignore[arg-type]
+                )
+
+        session_scalars = Mock()
+        session_scalars.all.return_value = [SimpleNamespace(name="author")]
+        with (
+            patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars),
+            patch.object(retrieval, "_fetch_model_config", return_value=(Mock(), Mock())),
+            patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])),
+            patch.object(retrieval, "_record_usage"),
+        ):
+            model_instance = Mock()
+            model_instance.invoke_llm.side_effect = RuntimeError("nope")
+            with patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, Mock())):
+                assert (
+                    retrieval._automatic_metadata_filter_func(
+                        dataset_ids=["d1"],
+                        query="python",
+                        tenant_id="tenant-1",
+                        user_id="user-1",
+                        metadata_model_config=WorkflowModelConfig(provider="openai", name="gpt", mode="chat"),
+                    )
+                    is None
+                )
+
+        with (
+            patch("core.rag.retrieval.dataset_retrieval.ModelMode", return_value=object()),
+            patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform"),
+        ):
+            with pytest.raises(ValueError, match="not support"):
+                retrieval._get_prompt_template(
+                    model_config=ModelConfigWithCredentialsEntity.model_construct(
+                        provider="openai",
+                        model="gpt",
+                        model_schema=Mock(),
+                        mode="chat",
+                        provider_model_bundle=Mock(),
+                        credentials={},
+                        parameters={},
+                        stop=[],
+                    ),
+                    mode="chat",
+                    metadata_fields=[],
+                    query="q",
+                )
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py
deleted file mode 100644
index 07d6e51e4b..0000000000
--- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py
+++ /dev/null
@@ -1,873 +0,0 @@
-"""
-Unit tests for DatasetRetrieval.process_metadata_filter_func.
-
-This module provides comprehensive test coverage for the process_metadata_filter_func
-method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
-filter expressions based on metadata filtering conditions.
-
-Conditions Tested:
-==================
-1. **String Conditions**: contains, not contains, start with, end with
-2. **Equality Conditions**: is / =, is not / ≠
-3. **Null Conditions**: empty, not empty
-4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
-5. **List Conditions**: in
-6. **Edge Cases**: None values, different data types (str, int, float)
-
-Test Architecture:
-==================
-- Direct instantiation of DatasetRetrieval
-- Mocking of DatasetDocument model attributes
-- Verification of SQLAlchemy filter expressions
-- Follows Arrange-Act-Assert (AAA) pattern
-
-Running Tests:
-==============
-    # Run all tests in this module
-    uv run --project api pytest \
-        api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
-
-    # Run a specific test
-    uv run --project api pytest \
-        api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
-TestProcessMetadataFilterFunc::test_contains_condition -v
-"""
-
-from unittest.mock import MagicMock
-
-import pytest
-
-from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
-
-
-class TestProcessMetadataFilterFunc:
-    """
-    Comprehensive test suite for process_metadata_filter_func method.
-
-    This test class validates all metadata filtering conditions supported by
-    the DatasetRetrieval class, including string operations, numeric comparisons,
-    null checks, and list operations.
-
-    Method Signature:
-    ==================
-    def process_metadata_filter_func(
-        self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
-    ) -> list:
-
-    The method builds SQLAlchemy filter expressions by:
-    1. Validating value is not None (except for empty/not empty conditions)
-    2. Using DatasetDocument.doc_metadata JSON field operations
-    3. Adding appropriate SQLAlchemy expressions to the filters list
-    4. Returning the updated filters list
-
-    Mocking Strategy:
-    ==================
-    - Mock DatasetDocument.doc_metadata to avoid database dependencies
-    - Verify filter expressions are created correctly
-    - Test with various data types (str, int, float, list)
-    """
-
-    @pytest.fixture
-    def retrieval(self):
-        """
-        Create a DatasetRetrieval instance for testing.
-
-        Returns:
-            DatasetRetrieval: Instance to test process_metadata_filter_func
-        """
-        return DatasetRetrieval()
-
-    @pytest.fixture
-    def mock_doc_metadata(self):
-        """
-        Mock the DatasetDocument.doc_metadata JSON field.
-
-        The method uses DatasetDocument.doc_metadata[metadata_name] to access
-        JSON fields. We mock this to avoid database dependencies.
-
-        Returns:
-            Mock: Mocked doc_metadata attribute
-        """
-        mock_metadata_field = MagicMock()
-
-        # Create mock for string access
-        mock_string_access = MagicMock()
-        mock_string_access.like = MagicMock()
-        mock_string_access.notlike = MagicMock()
-        mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
-        mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
-        mock_string_access.in_ = MagicMock(return_value=MagicMock())
-
-        # Create mock for float access (for numeric comparisons)
-        mock_float_access = MagicMock()
-        mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
-        mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
-        mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
-        mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
-        mock_float_access.__le__ = MagicMock(return_value=MagicMock())
-        mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
-
-        # Create mock for null checks
-        mock_null_access = MagicMock()
-        mock_null_access.is_ = MagicMock(return_value=MagicMock())
-        mock_null_access.isnot = MagicMock(return_value=MagicMock())
-
-        # Setup __getitem__ to return appropriate mock based on usage
-        def getitem_side_effect(name):
-            if name in ["author", "title", "category"]:
-                return mock_string_access
-            elif name in ["year", "price", "rating"]:
-                return mock_float_access
-            else:
-                return mock_string_access
-
-        mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
-        mock_metadata_field.as_string.return_value = mock_string_access
-        mock_metadata_field.as_float.return_value = mock_float_access
-        mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
-        mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
-
-        return mock_metadata_field
-
-    # ==================== String Condition Tests ====================
-
-    def test_contains_condition_string_value(self, retrieval):
-        """
-        Test 'contains' condition with string value.
-
-        Verifies:
-        - Filters list is populated with LIKE expression
-        - Pattern matching uses %value% syntax
-        """
-        filters = []
-        sequence = 0
-        condition = "contains"
-        metadata_name = "author"
-        value = "John"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_not_contains_condition(self, retrieval):
-        """
-        Test 'not contains' condition.
-
-        Verifies:
-        - Filters list is populated with NOT LIKE expression
-        - Pattern matching uses %value% syntax with negation
-        """
-        filters = []
-        sequence = 0
-        condition = "not contains"
-        metadata_name = "title"
-        value = "banned"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_start_with_condition(self, retrieval):
-        """
-        Test 'start with' condition.
-
-        Verifies:
-        - Filters list is populated with LIKE expression
-        - Pattern matching uses value% syntax
-        """
-        filters = []
-        sequence = 0
-        condition = "start with"
-        metadata_name = "category"
-        value = "tech"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_end_with_condition(self, retrieval):
-        """
-        Test 'end with' condition.
-
-        Verifies:
-        - Filters list is populated with LIKE expression
-        - Pattern matching uses %value syntax
-        """
-        filters = []
-        sequence = 0
-        condition = "end with"
-        metadata_name = "filename"
-        value = ".pdf"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    # ==================== Equality Condition Tests ====================
-
-    def test_is_condition_with_string_value(self, retrieval):
-        """
-        Test 'is' (=) condition with string value.
-
-        Verifies:
-        - Filters list is populated with equality expression
-        - String comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "is"
-        metadata_name = "author"
-        value = "Jane Doe"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_equals_condition_with_string_value(self, retrieval):
-        """
-        Test '=' condition with string value.
-
-        Verifies:
-        - Same behavior as 'is' condition
-        - String comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "="
-        metadata_name = "category"
-        value = "technology"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_is_condition_with_int_value(self, retrieval):
-        """
-        Test 'is' condition with integer value.
-
-        Verifies:
-        - Numeric comparison is used
-        - as_float() is called on the metadata field
-        """
-        filters = []
-        sequence = 0
-        condition = "is"
-        metadata_name = "year"
-        value = 2023
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_is_condition_with_float_value(self, retrieval):
-        """
-        Test 'is' condition with float value.
-
-        Verifies:
-        - Numeric comparison is used
-        - as_float() is called on the metadata field
-        """
-        filters = []
-        sequence = 0
-        condition = "is"
-        metadata_name = "price"
-        value = 19.99
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_is_not_condition_with_string_value(self, retrieval):
-        """
-        Test 'is not' (≠) condition with string value.
-
-        Verifies:
-        - Filters list is populated with inequality expression
-        - String comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "is not"
-        metadata_name = "author"
-        value = "Unknown"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_not_equals_condition(self, retrieval):
-        """
-        Test '≠' condition with string value.
-
-        Verifies:
-        - Same behavior as 'is not' condition
-        - Inequality expression is used
-        """
-        filters = []
-        sequence = 0
-        condition = "≠"
-        metadata_name = "category"
-        value = "archived"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_is_not_condition_with_numeric_value(self, retrieval):
-        """
-        Test 'is not' condition with numeric value.
-
-        Verifies:
-        - Numeric inequality comparison is used
-        - as_float() is called on the metadata field
-        """
-        filters = []
-        sequence = 0
-        condition = "is not"
-        metadata_name = "year"
-        value = 2000
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    # ==================== Null Condition Tests ====================
-
-    def test_empty_condition(self, retrieval):
-        """
-        Test 'empty' condition (null check).
-
-        Verifies:
-        - Filters list is populated with IS NULL expression
-        - Value can be None for this condition
-        """
-        filters = []
-        sequence = 0
-        condition = "empty"
-        metadata_name = "author"
-        value = None
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_not_empty_condition(self, retrieval):
-        """
-        Test 'not empty' condition (not null check).
-
-        Verifies:
-        - Filters list is populated with IS NOT NULL expression
-        - Value can be None for this condition
-        """
-        filters = []
-        sequence = 0
-        condition = "not empty"
-        metadata_name = "description"
-        value = None
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    # ==================== Numeric Comparison Tests ====================
-
-    def test_before_condition(self, retrieval):
-        """
-        Test 'before' (<) condition.
-
-        Verifies:
-        - Filters list is populated with less than expression
-        - Numeric comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "before"
-        metadata_name = "year"
-        value = 2020
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_less_than_condition(self, retrieval):
-        """
-        Test '<' condition.
-
-        Verifies:
-        - Same behavior as 'before' condition
-        - Less than expression is used
-        """
-        filters = []
-        sequence = 0
-        condition = "<"
-        metadata_name = "price"
-        value = 100.0
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_after_condition(self, retrieval):
-        """
-        Test 'after' (>) condition.
-
-        Verifies:
-        - Filters list is populated with greater than expression
-        - Numeric comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "after"
-        metadata_name = "year"
-        value = 2020
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_greater_than_condition(self, retrieval):
-        """
-        Test '>' condition.
-
-        Verifies:
-        - Same behavior as 'after' condition
-        - Greater than expression is used
-        """
-        filters = []
-        sequence = 0
-        condition = ">"
-        metadata_name = "rating"
-        value = 4.5
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_less_than_or_equal_condition_unicode(self, retrieval):
-        """
-        Test '≤' condition.
-
-        Verifies:
-        - Filters list is populated with less than or equal expression
-        - Numeric comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "≤"
-        metadata_name = "price"
-        value = 50.0
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_less_than_or_equal_condition_ascii(self, retrieval):
-        """
-        Test '<=' condition.
-
-        Verifies:
-        - Same behavior as '≤' condition
-        - Less than or equal expression is used
-        """
-        filters = []
-        sequence = 0
-        condition = "<="
-        metadata_name = "year"
-        value = 2023
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_greater_than_or_equal_condition_unicode(self, retrieval):
-        """
-        Test '≥' condition.
-
-        Verifies:
-        - Filters list is populated with greater than or equal expression
-        - Numeric comparison is used
-        """
-        filters = []
-        sequence = 0
-        condition = "≥"
-        metadata_name = "rating"
-        value = 3.5
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_greater_than_or_equal_condition_ascii(self, retrieval):
-        """
-        Test '>=' condition.
-
-        Verifies:
-        - Same behavior as '≥' condition
-        - Greater than or equal expression is used
-        """
-        filters = []
-        sequence = 0
-        condition = ">="
-        metadata_name = "year"
-        value = 2000
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    # ==================== List/In Condition Tests ====================
-
-    def test_in_condition_with_comma_separated_string(self, retrieval):
-        """
-        Test 'in' condition with comma-separated string value.
-
-        Verifies:
-        - String is split into list
-        - Whitespace is trimmed from each value
-        - IN expression is created
-        """
-        filters = []
-        sequence = 0
-        condition = "in"
-        metadata_name = "category"
-        value = "tech, science, AI "
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_in_condition_with_list_value(self, retrieval):
-        """
-        Test 'in' condition with list value.
-
-        Verifies:
-        - List is processed correctly
-        - None values are filtered out
-        - IN expression is created with valid values
-        """
-        filters = []
-        sequence = 0
-        condition = "in"
-        metadata_name = "tags"
-        value = ["python", "javascript", None, "golang"]
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_in_condition_with_tuple_value(self, retrieval):
-        """
-        Test 'in' condition with tuple value.
-
-        Verifies:
-        - Tuple is processed like a list
-        - IN expression is created
-        """
-        filters = []
-        sequence = 0
-        condition = "in"
-        metadata_name = "category"
-        value = ("tech", "science", "ai")
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_in_condition_with_empty_string(self, retrieval):
-        """
-        Test 'in' condition with empty string value.
-
-        Verifies:
-        - Empty string results in literal(False) filter
-        - No valid values to match
-        """
-        filters = []
-        sequence = 0
-        condition = "in"
-        metadata_name = "category"
-        value = ""
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-        # Verify it's a literal(False) expression
-        # This is a bit tricky to test without access to the actual expression
-
-    def test_in_condition_with_only_whitespace(self, retrieval):
-        """
-        Test 'in' condition with whitespace-only string value.
-
-        Verifies:
-        - Whitespace-only string results in literal(False) filter
-        - All values are stripped and filtered out
-        """
-        filters = []
-        sequence = 0
-        condition = "in"
-        metadata_name = "category"
-        value = " , , "
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_in_condition_with_single_string(self, retrieval):
-        """
-        Test 'in' condition with single non-comma string.
-
-        Verifies:
-        - Single string is treated as single-item list
-        - IN expression is created with one value
-        """
-        filters = []
-        sequence = 0
-        condition = "in"
-        metadata_name = "category"
-        value = "technology"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    # ==================== Edge Case Tests ====================
-
-    def test_none_value_with_non_empty_condition(self, retrieval):
-        """
-        Test None value with conditions that require value.
-
-        Verifies:
-        - Original filters list is returned unchanged
-        - No filter is added for None values (except empty/not empty)
-        """
-        filters = []
-        sequence = 0
-        condition = "contains"
-        metadata_name = "author"
-        value = None
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 0  # No filter added
-
-    def test_none_value_with_equals_condition(self, retrieval):
-        """
-        Test None value with 'is' (=) condition.
-
-        Verifies:
-        - Original filters list is returned unchanged
-        - No filter is added for None values
-        """
-        filters = []
-        sequence = 0
-        condition = "is"
-        metadata_name = "author"
-        value = None
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 0
-
-    def test_none_value_with_numeric_condition(self, retrieval):
-        """
-        Test None value with numeric comparison condition.
-
-        Verifies:
-        - Original filters list is returned unchanged
-        - No filter is added for None values
-        """
-        filters = []
-        sequence = 0
-        condition = ">"
-        metadata_name = "year"
-        value = None
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 0
-
-    def test_existing_filters_preserved(self, retrieval):
-        """
-        Test that existing filters are preserved.
-
-        Verifies:
-        - Existing filters in the list are not removed
-        - New filters are appended to the list
-        """
-        existing_filter = MagicMock()
-        filters = [existing_filter]
-        sequence = 0
-        condition = "contains"
-        metadata_name = "author"
-        value = "test"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 2
-        assert filters[0] == existing_filter
-
-    def test_multiple_filters_accumulated(self, retrieval):
-        """
-        Test multiple calls to accumulate filters.
-
-        Verifies:
-        - Each call adds a new filter to the list
-        - All filters are preserved across calls
-        """
-        filters = []
-
-        # First filter
-        retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
-        assert len(filters) == 1
-
-        # Second filter
-        retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
-        assert len(filters) == 2
-
-        # Third filter
-        retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
-        assert len(filters) == 3
-
-    def test_unknown_condition(self, retrieval):
-        """
-        Test unknown/unsupported condition.
-
-        Verifies:
-        - Original filters list is returned unchanged
-        - No filter is added for unknown conditions
-        """
-        filters = []
-        sequence = 0
-        condition = "unknown_condition"
-        metadata_name = "author"
-        value = "test"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 0
-
-    def test_empty_string_value_with_contains(self, retrieval):
-        """
-        Test empty string value with 'contains' condition.
-
-        Verifies:
-        - Filter is added even with empty string
-        - LIKE expression is created
-        """
-        filters = []
-        sequence = 0
-        condition = "contains"
-        metadata_name = "author"
-        value = ""
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_special_characters_in_value(self, retrieval):
-        """
-        Test special characters in value string.
-
-        Verifies:
-        - Special characters are handled in value
-        - LIKE expression is created correctly
-        """
-        filters = []
-        sequence = 0
-        condition = "contains"
-        metadata_name = "title"
-        value = "C++ & Python's features"
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_zero_value_with_numeric_condition(self, retrieval):
-        """
-        Test zero value with numeric comparison condition.
-
-        Verifies:
-        - Zero is treated as valid value
-        - Numeric comparison is performed
-        """
-        filters = []
-        sequence = 0
-        condition = ">"
-        metadata_name = "price"
-        value = 0
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_negative_value_with_numeric_condition(self, retrieval):
-        """
-        Test negative value with numeric comparison condition.
-
-        Verifies:
-        - Negative numbers are handled correctly
-        - Numeric comparison is performed
-        """
-        filters = []
-        sequence = 0
-        condition = "<"
-        metadata_name = "temperature"
-        value = -10.5
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
-
-    def test_float_value_with_integer_comparison(self, retrieval):
-        """
-        Test float value with numeric comparison condition.
-
-        Verifies:
-        - Float values work correctly
-        - Numeric comparison is performed
-        """
-        filters = []
-        sequence = 0
-        condition = ">="
-        metadata_name = "rating"
-        value = 4.5
-
-        result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
-
-        assert result == filters
-        assert len(filters) == 1
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py
deleted file mode 100644
index 5f461d53ae..0000000000
--- a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import threading
-from unittest.mock import Mock, patch
-from uuid import uuid4
-
-import pytest
-from flask import Flask, current_app
-
-from core.rag.models.document import Document
-from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
-from models.dataset import Dataset
-
-
-class TestRetrievalService:
-    @pytest.fixture
-    def mock_dataset(self) -> Dataset:
-        dataset = Mock(spec=Dataset)
-        dataset.id = str(uuid4())
-        dataset.tenant_id = str(uuid4())
-        dataset.name = "test_dataset"
-        dataset.indexing_technique = "high_quality"
-        dataset.provider = "dify"
-        return dataset
-
-    def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
-        """
-        Repro test for current bug:
-        reranking runs after `with flask_app.app_context():` exits.
-        `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
-        so we must assert from that list (not from an outer try/except).
- """ - dataset_retrieval = DatasetRetrieval() - flask_app = Flask(__name__) - tenant_id = str(uuid4()) - - # second dataset to ensure dataset_count > 1 reranking branch - secondary_dataset = Mock(spec=Dataset) - secondary_dataset.id = str(uuid4()) - secondary_dataset.provider = "dify" - secondary_dataset.indexing_technique = "high_quality" - - # retriever returns 1 doc into internal list (all_documents_item) - document = Document( - page_content="Context aware doc", - metadata={ - "doc_id": "doc1", - "score": 0.95, - "document_id": str(uuid4()), - "dataset_id": mock_dataset.id, - }, - provider="dify", - ) - - def fake_retriever( - flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids - ): - all_documents.append(document) - - called = {"init": 0, "invoke": 0} - - class ContextRequiredPostProcessor: - def __init__(self, *args, **kwargs): - called["init"] += 1 - # will raise RuntimeError if no Flask app context exists - _ = current_app.name - - def invoke(self, *args, **kwargs): - called["invoke"] += 1 - _ = current_app.name - return kwargs.get("documents") or args[1] - - # output list from _multiple_retrieve_thread - all_documents: list[Document] = [] - - # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here - thread_exceptions: list[Exception] = [] - - def target(): - with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): - with patch( - "core.rag.retrieval.dataset_retrieval.DataPostProcessor", - ContextRequiredPostProcessor, - ): - dataset_retrieval._multiple_retrieve_thread( - flask_app=flask_app, - available_datasets=[mock_dataset, secondary_dataset], - metadata_condition=None, - metadata_filter_document_ids=None, - all_documents=all_documents, - tenant_id=tenant_id, - reranking_enable=True, - reranking_mode="reranking_model", - reranking_model={ - "reranking_provider_name": "cohere", - "reranking_model_name": "rerank-v2", - }, - weights=None, - top_k=3, - score_threshold=0.0, - query="test query", - attachment_id=None, - dataset_count=2, # force reranking branch - thread_exceptions=thread_exceptions, # ✅ key - ) - - t = threading.Thread(target=target) - t.start() - t.join() - - # Ensure reranking branch was actually executed - assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." 
-
-        # Current buggy code should record an exception (not raise it)
-        assert not thread_exceptions, thread_exceptions
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py
new file mode 100644
index 0000000000..cfa9094e12
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py
@@ -0,0 +1,100 @@
+from unittest.mock import Mock
+
+from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
+from dify_graph.model_runtime.entities.llm_entities import LLMUsage
+
+
+class TestFunctionCallMultiDatasetRouter:
+    def test_invoke_returns_none_when_no_tools(self) -> None:
+        router = FunctionCallMultiDatasetRouter()
+
+        dataset_id, usage = router.invoke(
+            query="python",
+            dataset_tools=[],
+            model_config=Mock(),
+            model_instance=Mock(),
+        )
+
+        assert dataset_id is None
+        assert usage == LLMUsage.empty_usage()
+
+    def test_invoke_returns_single_tool_directly(self) -> None:
+        router = FunctionCallMultiDatasetRouter()
+        tool = Mock()
+        tool.name = "dataset-1"
+
+        dataset_id, usage = router.invoke(
+            query="python",
+            dataset_tools=[tool],
+            model_config=Mock(),
+            model_instance=Mock(),
+        )
+
+        assert dataset_id == "dataset-1"
+        assert usage == LLMUsage.empty_usage()
+
+    def test_invoke_returns_tool_from_model_response(self) -> None:
+        router = FunctionCallMultiDatasetRouter()
+        tool_1 = Mock()
+        tool_1.name = "dataset-1"
+        tool_2 = Mock()
+        tool_2.name = "dataset-2"
+        usage = LLMUsage.empty_usage()
+        response = Mock()
+        response.usage = usage
+        response.message.tool_calls = [Mock(function=Mock())]
+        response.message.tool_calls[0].function.name = "dataset-2"
+        model_instance = Mock()
+        model_instance.invoke_llm.return_value = response
+
+        dataset_id, returned_usage = router.invoke(
+            query="python",
+            dataset_tools=[tool_1, tool_2],
+            model_config=Mock(),
+            model_instance=model_instance,
+        )
+
+        assert dataset_id == "dataset-2"
+        assert returned_usage == usage
+        model_instance.invoke_llm.assert_called_once()
+
+    def test_invoke_returns_none_when_no_tool_calls(self) -> None:
+        router = FunctionCallMultiDatasetRouter()
+        response = Mock()
+        response.usage = LLMUsage.empty_usage()
+        response.message.tool_calls = []
+        model_instance = Mock()
+        model_instance.invoke_llm.return_value = response
+        tool_1 = Mock()
+        tool_1.name = "dataset-1"
+        tool_2 = Mock()
+        tool_2.name = "dataset-2"
+
+        dataset_id, usage = router.invoke(
+            query="python",
+            dataset_tools=[tool_1, tool_2],
+            model_config=Mock(),
+            model_instance=model_instance,
+        )
+
+        assert dataset_id is None
+        assert usage == response.usage
+
+    def test_invoke_returns_empty_usage_when_model_raises(self) -> None:
+        router = FunctionCallMultiDatasetRouter()
+        model_instance = Mock()
+        model_instance.invoke_llm.side_effect = RuntimeError("boom")
+        tool_1 = Mock()
+        tool_1.name = "dataset-1"
+        tool_2 = Mock()
+        tool_2.name = "dataset-2"
+
+        dataset_id, usage = router.invoke(
+            query="python",
+            dataset_tools=[tool_1, tool_2],
+            model_config=Mock(),
+            model_instance=model_instance,
+        )
+
+        assert dataset_id is None
+        assert usage == LLMUsage.empty_usage()
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py
new file mode 100644
index 0000000000..e429563739
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py
@@ -0,0 +1,252 @@
+from types import SimpleNamespace
+from unittest.mock import Mock, patch
+
+from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
+from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from dify_graph.model_runtime.entities.llm_entities import LLMUsage
+from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
+
+
+class TestReactMultiDatasetRouter:
+    def test_invoke_returns_none_when_no_tools(self) -> None:
+        router = ReactMultiDatasetRouter()
+
+        dataset_id, usage = router.invoke(
+            query="python",
+            dataset_tools=[],
+            model_config=Mock(),
+            model_instance=Mock(),
+            user_id="u1",
+            tenant_id="t1",
+        )
+
+        assert dataset_id is None
+        assert usage == LLMUsage.empty_usage()
+
+    def test_invoke_returns_single_tool_directly(self) -> None:
+        router = ReactMultiDatasetRouter()
+        tool = Mock()
+        tool.name = "dataset-1"
+
+        dataset_id, usage = router.invoke(
+            query="python",
+            dataset_tools=[tool],
+            model_config=Mock(),
+            model_instance=Mock(),
+            user_id="u1",
+            tenant_id="t1",
+        )
+
+        assert dataset_id == "dataset-1"
+        assert usage == LLMUsage.empty_usage()
+
+    def test_invoke_returns_tool_from_react_invoke(self) -> None:
+        router = ReactMultiDatasetRouter()
+        usage = LLMUsage.empty_usage()
+        tool_1 = Mock(name="dataset-1")
+        tool_1.name = "dataset-1"
+        tool_2 = Mock(name="dataset-2")
+        tool_2.name = "dataset-2"
+
+        with patch.object(router, "_react_invoke", return_value=("dataset-2", usage)) as mock_react:
+            dataset_id, returned_usage = router.invoke(
+                query="python",
+                dataset_tools=[tool_1, tool_2],
+                model_config=Mock(),
+                model_instance=Mock(),
+                user_id="u1",
+                tenant_id="t1",
+            )
+
+        mock_react.assert_called_once()
+        assert dataset_id == "dataset-2"
+        assert returned_usage == usage
+
+    def test_invoke_handles_react_invoke_errors(self) -> None:
+        router = ReactMultiDatasetRouter()
+        tool_1 = Mock()
+        tool_1.name = "dataset-1"
+        tool_2 = Mock()
+        tool_2.name = "dataset-2"
+
+        with patch.object(router, "_react_invoke", side_effect=RuntimeError("boom")):
+            dataset_id, usage = router.invoke(
+                query="python",
+                dataset_tools=[tool_1, tool_2],
+                model_config=Mock(),
+                model_instance=Mock(),
+                user_id="u1",
+                tenant_id="t1",
+            )
+
+        assert dataset_id is None
+        assert usage == LLMUsage.empty_usage()
+
+    def test_react_invoke_returns_action_tool(self) -> None:
+        router = ReactMultiDatasetRouter()
+        model_config = Mock()
+        model_config.mode = "chat"
+        model_config.parameters = {"temperature": 0.1}
+        usage = LLMUsage.empty_usage()
+        tools = [Mock(name="dataset-1"), Mock(name="dataset-2")]
+        tools[0].name = "dataset-1"
+        tools[0].description = "desc"
+        tools[1].name = "dataset-2"
+        tools[1].description = "desc"
+
+        with (
+            patch.object(router, "create_chat_prompt", return_value=[Mock()]) as mock_chat_prompt,
+            patch(
+                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
+            ) as mock_prompt_transform,
+            patch.object(router, "_invoke_llm", return_value=('{"action":"dataset-2","action_input":{}}', usage)),
+            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
+        ):
+            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
+            mock_parser_cls.return_value.parse.return_value = ReactAction("dataset-2", {}, "log")
+
+            dataset_id, returned_usage = router._react_invoke(
+                query="python",
+                model_config=model_config,
+                model_instance=Mock(),
+                tools=tools,
+                user_id="u1",
+                tenant_id="t1",
+            )
+
+        mock_chat_prompt.assert_called_once()
+        assert dataset_id == "dataset-2"
+        assert returned_usage == usage
+
+    def test_react_invoke_returns_none_for_finish(self) -> None:
+        router = ReactMultiDatasetRouter()
+        model_config = Mock()
+        model_config.mode = "completion"
+        model_config.parameters = {"temperature": 0.1}
+        usage = LLMUsage.empty_usage()
+        tool = Mock()
+        tool.name = "dataset-1"
+        tool.description = "desc"
+
+        with (
+            patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt,
+            patch(
+                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
+            ) as mock_prompt_transform,
+            patch.object(
+                router, "_invoke_llm", return_value=('{"action":"Final Answer","action_input":"done"}', usage)
+            ),
+            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
+        ):
+            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
+            mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log")
+
+            dataset_id, returned_usage = router._react_invoke(
+                query="python",
+                model_config=model_config,
+                model_instance=Mock(),
+                tools=[tool],
+                user_id="u1",
+                tenant_id="t1",
+            )
+
+        mock_completion_prompt.assert_called_once()
+        assert dataset_id is None
+        assert returned_usage == usage
+
+    def test_invoke_llm_and_handle_result(self) -> None:
+        router = ReactMultiDatasetRouter()
+        usage = LLMUsage.empty_usage()
+        delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=usage)
+        chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta)
+        model_instance = Mock()
+        model_instance.invoke_llm.return_value = iter([chunk])
+
+        with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct:
+            text, returned_usage = router._invoke_llm(
+                completion_param={"temperature": 0.1},
+                model_instance=model_instance,
+                prompt_messages=[Mock()],
+                stop=["Observation:"],
+                user_id="u1",
+                tenant_id="t1",
+            )
+
+        assert text == "part"
+        assert returned_usage == usage
+        mock_deduct.assert_called_once()
+
+    def test_handle_invoke_result_with_empty_usage(self) -> None:
+        router = ReactMultiDatasetRouter()
+        delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=None)
+        chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta)
+
+        text, usage = router._handle_invoke_result(iter([chunk]))
+
+        assert text == "part"
+        assert usage == LLMUsage.empty_usage()
+
+    def test_create_chat_prompt(self) -> None:
+        router = ReactMultiDatasetRouter()
+        tool_1 = Mock()
+        tool_1.name = "dataset-1"
+        tool_1.description = "d1"
+        tool_2 = Mock()
+        tool_2.name = "dataset-2"
+        tool_2.description = "d2"
+
+        chat_prompt = router.create_chat_prompt(query="python", tools=[tool_1, tool_2])
+        assert len(chat_prompt) == 2
+        assert chat_prompt[0].role == PromptMessageRole.SYSTEM
+        assert chat_prompt[1].role == PromptMessageRole.USER
+        assert "dataset-1" in chat_prompt[0].text
+        assert "dataset-2" in chat_prompt[0].text
+
+    def test_create_completion_prompt(self) -> None:
+        router = ReactMultiDatasetRouter()
+        tool_1 = Mock()
+        tool_1.name = "dataset-1"
+        tool_1.description = "d1"
+        tool_2 = Mock()
+        tool_2.name = "dataset-2"
+        tool_2.description = "d2"
+
+        completion_prompt = router.create_completion_prompt(tools=[tool_1, tool_2])
+        assert "dataset-1: d1" in completion_prompt.text
+        assert "dataset-2: d2" in completion_prompt.text
+
+    def test_react_invoke_uses_completion_branch_for_non_chat_mode(self) -> None:
+        router = ReactMultiDatasetRouter()
+        model_config = Mock()
+        model_config.mode = "unknown-mode"
+        model_config.parameters = {}
+        tool = Mock()
+        tool.name = "dataset-1"
+        tool.description = "desc"
+
+        with (
+            patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt,
+            patch(
+                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
+            ) as mock_prompt_transform,
+            patch.object(
+                router,
+                "_invoke_llm",
+                return_value=('{"action":"Final Answer","action_input":"done"}', LLMUsage.empty_usage()),
+            ),
+            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
+        ):
+            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
+            mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log")
+            dataset_id, usage = router._react_invoke(
+                query="python",
+                model_config=model_config,
+                model_instance=Mock(),
+                tools=[tool],
+                user_id="u1",
+                tenant_id="t1",
+            )
+
+        mock_completion_prompt.assert_called_once()
+        assert dataset_id is None
+        assert usage == LLMUsage.empty_usage()
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py
new file mode 100644
index 0000000000..c8fa0ea62f
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py
@@ -0,0 +1,69 @@
+import pytest
+
+from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
+from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
+
+
+class TestStructuredChatOutputParser:
+    def test_parse_action_without_action_input(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = 'Action:\n```json\n{"action":"some_action"}\n```'
+        result = parser.parse(text)
+
+        assert isinstance(result, ReactAction)
+        assert result.tool == "some_action"
+        assert result.tool_input == {}
+
+    def test_parse_json_without_action_key(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = 'Action:\n```json\n{"not_action":"search"}\n```'
+        with pytest.raises(ValueError, match="Could not parse LLM output"):
+            parser.parse(text)
+
+    def test_parse_returns_action_for_tool_call(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = (
+            'Thought: call tool\nAction:\n```json\n{"action":"search_dataset","action_input":{"query":"python"}}\n```'
+        )
+
+        result = parser.parse(text)
+
+        assert isinstance(result, ReactAction)
+        assert result.tool == "search_dataset"
+        assert result.tool_input == {"query": "python"}
+        assert result.log == text
+
+    def test_parse_returns_finish_for_final_answer(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = 'Thought: done\nAction:\n```json\n{"action":"Final Answer","action_input":"final text"}\n```'
+
+        result = parser.parse(text)
+
+        assert isinstance(result, ReactFinish)
+        assert result.return_values == {"output": "final text"}
+        assert result.log == text
+
+    def test_parse_returns_finish_for_json_array_payload(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = 'Action:\n```json\n[{"action":"search","action_input":"hello"}]\n```'
+        result = parser.parse(text)
+
+        assert isinstance(result, ReactFinish)
+        assert result.return_values == {"output": text}
+        assert result.log == text
+
+    def test_parse_returns_finish_for_plain_text(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = "No structured action block"
+
+        result = parser.parse(text)
+
+        assert isinstance(result, ReactFinish)
+        assert result.return_values == {"output": text}
+
+    def test_parse_raises_value_error_for_invalid_json(self) -> None:
+        parser = StructuredChatOutputParser()
+        text = 'Action:\n```json\n{"action":"search","action_input": }\n```'
+
+        with pytest.raises(ValueError, match="Could not parse LLM output"):
+            parser.parse(text)
diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py
index 943a9e5712..976de10d89 100644
--- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py
+++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py
@@ -125,7 +125,11 @@ Run with coverage:
 - Tests are organized by functionality in classes for better organization
 """
+import asyncio
 import string
+import sys
+import types
+from inspect import currentframe
 from unittest.mock import Mock, patch
 
 import pytest
@@ -604,6 +608,51 @@ class TestRecursiveCharacterTextSplitter:
         assert "def hello_world" in combined or "hello_world" in combined
 
 
+class TestTextSplitterBasePaths:
+    """Target uncovered base TextSplitter paths."""
+
+    def test_from_huggingface_tokenizer_success_path(self):
+        """Cover from_huggingface_tokenizer success branch with mocked transformers."""
+
+        class _FakePreTrainedTokenizerBase:
+            pass
+
+        class _FakeTokenizer(_FakePreTrainedTokenizerBase):
+            def encode(self, text: str):
+                return [ord(c) for c in text]
+
+        fake_transformers = types.SimpleNamespace(PreTrainedTokenizerBase=_FakePreTrainedTokenizerBase)
+        with patch.dict(sys.modules, {"transformers": fake_transformers}):
+            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+                tokenizer=_FakeTokenizer(),
+                chunk_size=5,
+                chunk_overlap=1,
+            )
+
+        chunks = splitter.split_text("abcdef")
+        assert chunks
+
+    def test_from_huggingface_tokenizer_import_error(self):
+        """Cover from_huggingface_tokenizer import-error branch."""
+        with patch.dict(sys.modules, {"transformers": None}):
+            with pytest.raises(ValueError, match="Could not import transformers"):
+                RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer=object(), chunk_size=5)
+
+    def test_atransform_documents_raises_not_implemented(self):
+        """Cover atransform_documents NotImplemented branch."""
+        splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
+        with pytest.raises(NotImplementedError):
+            asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})]))
+
+    def test_merge_splits_logs_warning_for_oversized_total(self):
+        """Cover logger.warning path in _merge_splits."""
+        splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1)
+        with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning:
+            merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1])
+        assert merged
+        mock_warning.assert_called_once()
+
+
 # ============================================================================
 # Test TokenTextSplitter
 # ============================================================================
@@ -662,6 +711,44 @@ class TestTokenTextSplitter:
         except ImportError:
             pytest.skip("tiktoken not installed")
 
+    def test_initialization_and_split_with_mocked_tiktoken_encoding(self):
+        """Cover TokenTextSplitter __init__ else-path and split_text logic."""
+
+        class _FakeEncoding:
+            def encode(self, text: str, allowed_special=None, disallowed_special=None):
+                return [ord(c) for c in text]
+
+            def decode(self, token_ids: list[int]) -> str:
+                return "".join(chr(i) for i in token_ids)
+
+        fake_tiktoken = types.SimpleNamespace(get_encoding=lambda name: _FakeEncoding())
+        with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}):
+            splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=4, chunk_overlap=1)
+            result = splitter.split_text("abcdefgh")
+
+        assert result
+        assert all(isinstance(chunk, str) for chunk in result)
+
+    def test_initialization_with_model_name_uses_encoding_for_model(self):
+        """Cover TokenTextSplitter model_name init branch."""
+
+        class _FakeEncoding:
+            def encode(self, text: str, allowed_special=None, disallowed_special=None):
+                return [ord(c) for c in text]
+
+            def decode(self, token_ids: list[int]) -> str:
+                return "".join(chr(i) for i in token_ids)
+
+        fake_encoding = _FakeEncoding()
+        fake_tiktoken = types.SimpleNamespace(
+            encoding_for_model=lambda model_name: fake_encoding,
+            get_encoding=lambda name: _FakeEncoding(),
+        )
+        with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}):
+            splitter = TokenTextSplitter(model_name="gpt-4", chunk_size=5, chunk_overlap=1)
+
+        assert splitter._tokenizer is fake_encoding
+
 
 # ============================================================================
 # Test EnhanceRecursiveCharacterTextSplitter
@@ -731,6 +818,50 @@ class TestEnhanceRecursiveCharacterTextSplitter:
         assert len(result) > 0
         assert all(isinstance(chunk, str) for chunk in result)
 
+    def test_from_encoder_internal_token_encoder_paths(self):
+        """
+        Test internal _token_encoder branches by capturing local closure from frame.
+
+        This validates:
+        - empty texts path
+        - embedding model path
+        - GPT2Tokenizer fallback path
+        - _character_encoder empty-path branch
+        """
+
+        class _SpySplitter(EnhanceRecursiveCharacterTextSplitter):
+            captured_token_encoder = None
+            captured_character_encoder = None
+
+            def __init__(self, **kwargs):
+                frame = currentframe()
+                if frame and frame.f_back:
+                    _SpySplitter.captured_token_encoder = frame.f_back.f_locals.get("_token_encoder")
+                    _SpySplitter.captured_character_encoder = frame.f_back.f_locals.get("_character_encoder")
+                super().__init__(**kwargs)
+
+        mock_model = Mock()
+        mock_model.get_text_embedding_num_tokens.return_value = [3, 5]
+
+        _SpySplitter.from_encoder(embedding_model_instance=mock_model, chunk_size=10, chunk_overlap=1)
+        token_encoder = _SpySplitter.captured_token_encoder
+        character_encoder = _SpySplitter.captured_character_encoder
+
+        assert token_encoder is not None
+        assert character_encoder is not None
+        assert token_encoder([]) == []
+        assert token_encoder(["abc", "defgh"]) == [3, 5]
+        assert character_encoder([]) == []
+
+        with patch(
+            "core.rag.splitter.fixed_text_splitter.GPT2Tokenizer.get_num_tokens",
+            side_effect=lambda text: len(text) + 1,
+        ):
+            _SpySplitter.from_encoder(embedding_model_instance=None, chunk_size=10, chunk_overlap=1)
+            token_encoder_without_model = _SpySplitter.captured_token_encoder
+        assert token_encoder_without_model is not None
+        assert token_encoder_without_model(["ab", "cdef"]) == [3, 5]
+
 
 # ============================================================================
 # Test FixedRecursiveCharacterTextSplitter
@@ -908,6 +1039,56 @@ class TestFixedRecursiveCharacterTextSplitter:
         chunks = splitter.split_text(data)
         assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
 
+    def test_recursive_split_keep_separator_and_recursive_fallback(self):
+        """Cover keep-separator split branch and recursive _split_text fallback."""
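+        # Added commentary (not part of the original patch): this test relies on
+        # two behaviors that are assumed here rather than pinned down exactly --
+        # with keep_separator=True the "." separator is expected to stay attached
+        # to its piece (so "short." survives as a unit), and any piece longer
+        # than chunk_size is expected to be re-split recursively using the next
+        # separator in the list (" ", then ""). The assertions below therefore
+        # check loose properties of the output rather than exact chunk contents.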
+ text = "short." + ("x" * 60) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=[".", " ", ""], + chunk_size=10, + chunk_overlap=2, + keep_separator=True, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert any("short." in chunk for chunk in chunks) + assert any(len(chunk) <= 12 for chunk in chunks) + + def test_recursive_split_newline_separator_filtering(self): + """Cover newline-specific empty filtering branch.""" + text = "line1\n\nline2\n\nline3" + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n", ""], + chunk_size=50, + chunk_overlap=5, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert all(chunk != "" for chunk in chunks) + assert "line1" in "".join(chunks) + assert "line2" in "".join(chunks) + assert "line3" in "".join(chunks) + + def test_recursive_split_without_new_separator_appends_long_chunk(self): + """Cover branch where no further separators exist and long split is appended directly.""" + text = "aa\n" + ("b" * 40) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n"], + chunk_size=10, + chunk_overlap=2, + ) + + chunks = splitter.recursive_split_text(text) + + assert "aa" in chunks + assert any(len(chunk) >= 40 for chunk in chunks) + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py new file mode 100644 index 0000000000..c66e50437a --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -0,0 +1,84 @@ +from datetime import datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType +from models import Account, WorkflowRun +from models.enums import WorkflowRunTriggeredFrom + + +def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository: + engine = create_engine("sqlite:///:memory:") + real_session_factory = sessionmaker(bind=engine, expire_on_commit=False) + + user = MagicMock(spec=Account) + user.id = str(uuid4()) + user.current_tenant_id = str(uuid4()) + + repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=real_session_factory, + user=user, + app_id="app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = False + repository._session_factory = MagicMock(return_value=session_context) + return repository + + +def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution: + return WorkflowExecution.new( + id_=execution_id, + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0.0", + graph={"nodes": [], "edges": []}, + inputs={"query": "hello"}, + started_at=started_at, + ) + + +def test_save_uses_execution_started_at_when_record_does_not_exist(): + session = MagicMock() + session.get.return_value = None + repository = _build_repository_with_mocked_session(session) + + started_at 
= datetime(2026, 1, 1, 12, 0, 0) + execution = _build_execution(execution_id=str(uuid4()), started_at=started_at) + + repository.save(execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + +def test_save_preserves_existing_created_at_when_record_already_exists(): + session = MagicMock() + repository = _build_repository_with_mocked_session(session) + + execution_id = str(uuid4()) + existing_created_at = datetime(2026, 1, 1, 12, 0, 0) + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repository._tenant_id + existing_run.created_at = existing_created_at + session.get.return_value = existing_run + + execution = _build_execution( + execution_id=execution_id, + started_at=datetime(2026, 1, 1, 12, 30, 0), + ) + + repository.save(execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/trigger/__init__.py b/api/tests/unit_tests/core/trigger/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/conftest.py b/api/tests/unit_tests/core/trigger/conftest.py new file mode 100644 index 0000000000..d9da80a8b7 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/conftest.py @@ -0,0 +1,93 @@ +"""Shared factory helpers for core.trigger test suite.""" + +from __future__ import annotations + +from typing import Any + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.entities.entities import ( + EventEntity, + EventIdentity, + EventParameter, + OAuthSchema, + Subscription, + SubscriptionConstructor, + TriggerProviderEntity, + TriggerProviderIdentity, +) +from core.trigger.provider import PluginTriggerProviderController +from models.provider_ids import TriggerProviderID + +# Valid format for TriggerProviderID: org/plugin/provider +VALID_PROVIDER_ID = "testorg/testplugin/testprovider" + + +def i18n(text: str = "test") -> I18nObject: + return I18nObject(en_US=text, zh_Hans=text) + + +def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity: + return EventEntity( + identity=EventIdentity(author="a", name=name, label=i18n(name)), + description=i18n(name), + parameters=parameters or [], + ) + + +def make_provider_entity( + name: str = "test_provider", + events: list[EventEntity] | None = None, + constructor: SubscriptionConstructor | None = None, + subscription_schema: list[ProviderConfig] | None = None, + icon: str | None = "icon.png", + icon_dark: str | None = None, +) -> TriggerProviderEntity: + return TriggerProviderEntity( + identity=TriggerProviderIdentity( + author="a", + name=name, + label=i18n(name), + description=i18n(name), + icon=icon, + icon_dark=icon_dark, + ), + events=events if events is not None else [make_event()], + subscription_constructor=constructor, + subscription_schema=subscription_schema or [], + ) + + +def make_controller( + entity: TriggerProviderEntity | None = None, + tenant_id: str = "tenant-1", + provider_id: str = VALID_PROVIDER_ID, +) -> PluginTriggerProviderController: + return PluginTriggerProviderController( + entity=entity or make_provider_entity(), + plugin_id="plugin-1", + plugin_unique_identifier="uid-1", + provider_id=TriggerProviderID(provider_id), + tenant_id=tenant_id, + ) + + +def make_subscription(**overrides: Any) 
diff --git a/api/tests/unit_tests/core/trigger/debug/__init__.py b/api/tests/unit_tests/core/trigger/debug/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py
new file mode 100644
index 0000000000..d557c20f5e
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py
@@ -0,0 +1,93 @@
+"""
+Tests for core.trigger.debug.event_bus.TriggerDebugEventBus.
+
+Covers: Lua-script dispatch/poll with Redis error resilience.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+from redis import RedisError
+
+from core.trigger.debug.event_bus import TriggerDebugEventBus
+from core.trigger.debug.events import PluginTriggerDebugEvent
+
+
+class TestDispatch:
+    @patch("core.trigger.debug.event_bus.redis_client")
+    def test_returns_dispatch_count(self, mock_redis):
+        mock_redis.eval.return_value = 3
+        event = MagicMock()
+        event.model_dump_json.return_value = '{"test": true}'
+
+        result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
+
+        assert result == 3
+        mock_redis.eval.assert_called_once()
+
+    @patch("core.trigger.debug.event_bus.redis_client")
+    def test_redis_error_returns_zero(self, mock_redis):
+        mock_redis.eval.side_effect = RedisError("connection lost")
+        event = MagicMock()
+        event.model_dump_json.return_value = "{}"
+
+        result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
+
+        assert result == 0
+
+
+class TestPoll:
+    @patch("core.trigger.debug.event_bus.redis_client")
+    def test_returns_deserialized_event(self, mock_redis):
+        event_json = PluginTriggerDebugEvent(
+            timestamp=100,
+            name="push",
+            user_id="u1",
+            request_id="r1",
+            subscription_id="s1",
+            provider_id="p1",
+        ).model_dump_json()
+        mock_redis.eval.return_value = event_json
+
+        result = TriggerDebugEventBus.poll(
+            event_type=PluginTriggerDebugEvent,
+            pool_key="pool:key",
+            tenant_id="t1",
+            user_id="u1",
+            app_id="a1",
+            node_id="n1",
+        )
+
+        assert result is not None
+        assert result.name == "push"
+
+    @patch("core.trigger.debug.event_bus.redis_client")
+    def test_returns_none_when_no_event(self, mock_redis):
+        mock_redis.eval.return_value = None
+
+        result = TriggerDebugEventBus.poll(
+            event_type=PluginTriggerDebugEvent,
+            pool_key="pool:key",
+            tenant_id="t1",
+            user_id="u1",
+            app_id="a1",
+            node_id="n1",
+        )
+
+        assert result is None
+
+    @patch("core.trigger.debug.event_bus.redis_client")
+    def test_redis_error_returns_none(self, mock_redis):
+        mock_redis.eval.side_effect = RedisError("timeout")
+
+        result = TriggerDebugEventBus.poll(
+            event_type=PluginTriggerDebugEvent,
+            pool_key="pool:key",
+            tenant_id="t1",
+            user_id="u1",
+            app_id="a1",
+            node_id="n1",
+        )
+
+        assert result is None
diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
new file mode 100644
index 0000000000..331bcd6c25
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
@@ -0,0 +1,276 @@
+"""
+Tests for core.trigger.debug.event_selectors.
+
+Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory,
+and select_trigger_debug_events orchestrator.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.plugin.entities.request import TriggerInvokeEventResponse
+from core.trigger.debug.event_selectors import (
+    PluginTriggerDebugEventPoller,
+    ScheduleTriggerDebugEventPoller,
+    WebhookTriggerDebugEventPoller,
+    create_event_poller,
+    select_trigger_debug_events,
+)
+from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
+from dify_graph.enums import NodeType
+from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
+
+
+def _make_poller_args(node_config: dict | None = None) -> dict:
+    return {
+        "tenant_id": "t1",
+        "user_id": "u1",
+        "app_id": "a1",
+        "node_config": node_config or {"data": {}},
+        "node_id": "n1",
+    }
+
+
+def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict:
+    """Valid node config for TriggerEventNodeData.model_validate."""
+    return {
+        "data": {
+            "title": "test",
+            "plugin_id": "org/testplugin",
+            "provider_id": provider_id,
+            "event_name": "push",
+            "subscription_id": "s1",
+            "plugin_unique_identifier": "uid-1",
+        }
+    }
+
+
+class TestPluginTriggerDebugEventPoller:
+    @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
+    def test_returns_workflow_args_on_success(self, mock_bus):
+        event = PluginTriggerDebugEvent(
+            timestamp=100,
+            name="push",
+            user_id="u1",
+            request_id="r1",
+            subscription_id="s1",
+            provider_id="p1",
+        )
+        mock_bus.poll.return_value = event
+
+        with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
+            mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
+                variables={"repo": "dify"},
+                cancelled=False,
+            )
+
+            poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
+            result = poller.poll()
+
+            assert result is not None
+            assert result.workflow_args["inputs"] == {"repo": "dify"}
+
+    @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
+    def test_returns_none_when_no_event(self, mock_bus):
+        mock_bus.poll.return_value = None
+
+        poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
+
+        assert poller.poll() is None
+
+    @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
+    def test_returns_none_when_invoke_cancelled(self, mock_bus):
+        event = PluginTriggerDebugEvent(
+            timestamp=100,
+            name="push",
+            user_id="u1",
+            request_id="r1",
+            subscription_id="s1",
+            provider_id="p1",
+        )
+        mock_bus.poll.return_value = event
+
+        with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
+            mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
+                variables={},
+                cancelled=True,
+            )
+
+            poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
+
+            assert poller.poll() is None
+
+
+class TestWebhookTriggerDebugEventPoller:
+    @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
+    def test_uses_inputs_directly_when_present(self, mock_bus):
+        event = WebhookDebugEvent(
+            timestamp=100,
+            request_id="r1",
+            node_id="n1",
+            payload={"inputs": {"key": "val"}, "webhook_data": {}},
+        )
+        mock_bus.poll.return_value = event
+
+        poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
+        result = poller.poll()
+
+        assert result is not None
+        assert result.workflow_args["inputs"] == {"key": "val"}
+
+    @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
+    def test_falls_back_to_webhook_data(self, mock_bus):
+        event = WebhookDebugEvent(
+            timestamp=100,
+            request_id="r1",
+            node_id="n1",
+            payload={"webhook_data": {"body": "raw"}},
+        )
+        mock_bus.poll.return_value = event
+
+        with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc:
+            mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"}
+
+            poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
+            result = poller.poll()
+
+            assert result is not None
+            assert result.workflow_args["inputs"] == {"parsed": "data"}
+            mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"})
+
+    @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
+    def test_returns_none_when_no_event(self, mock_bus):
+        mock_bus.poll.return_value = None
+        poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
+
+        assert poller.poll() is None
+
+
+class TestScheduleTriggerDebugEventPoller:
+    def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime):
+        """Set up mocks and create a schedule poller."""
+        mock_redis.get.return_value = None
+        mock_schedule_config = MagicMock()
+        mock_schedule_config.cron_expression = "0 * * * *"
+        mock_schedule_config.timezone = "UTC"
+        mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
+        return ScheduleTriggerDebugEventPoller(**_make_poller_args())
+
+    @patch("core.trigger.debug.event_selectors.redis_client")
+    @patch("core.trigger.debug.event_selectors.naive_utc_now")
+    @patch("core.trigger.debug.event_selectors.calculate_next_run_at")
+    @patch("core.trigger.debug.event_selectors.ensure_naive_utc")
+    def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
+        now = datetime(2025, 1, 1, 12, 0, 0)
+        next_run = datetime(2025, 1, 1, 13, 0, 0)  # future
+        mock_now.return_value = now
+        mock_calc.return_value = next_run
+        mock_ensure.return_value = next_run
+        mock_redis.get.return_value = None
+
+        with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
+            mock_schedule_config = MagicMock()
+            mock_schedule_config.cron_expression = "0 * * * *"
+            mock_schedule_config.timezone = "UTC"
+            mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
+
+            poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
+
+            assert poller.poll() is None
+
+    @patch("core.trigger.debug.event_selectors.redis_client")
+    @patch("core.trigger.debug.event_selectors.naive_utc_now")
+    @patch("core.trigger.debug.event_selectors.calculate_next_run_at")
+    @patch("core.trigger.debug.event_selectors.ensure_naive_utc")
+    def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
+        now = datetime(2025, 1, 1, 14, 0, 0)
+        next_run = datetime(2025, 1, 1, 12, 0, 0)  # past
+        mock_now.return_value = now
+        mock_calc.return_value = next_run
+        mock_ensure.return_value = next_run
+        mock_redis.get.return_value = None
+
+        with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
+            mock_schedule_config = MagicMock()
+            mock_schedule_config.cron_expression = "0 * * * *"
+            mock_schedule_config.timezone = "UTC"
+            mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
+
+            poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
+            result = poller.poll()
+
+            assert result is not None
+            mock_redis.delete.assert_called_once()
+
+
+class TestCreateEventPoller:
+    def _workflow_with_node(self, node_type: NodeType):
+        wf = MagicMock()
+        wf.get_node_config_by_id.return_value = {"data": {}}
+        wf.get_node_type_from_node_config.return_value = node_type
+        return wf
+
+    def test_creates_plugin_poller(self):
+        wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN)
+        poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
+        assert isinstance(poller, PluginTriggerDebugEventPoller)
+
+    def test_creates_webhook_poller(self):
+        wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK)
+        poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
+        assert isinstance(poller, WebhookTriggerDebugEventPoller)
+
+    def test_creates_schedule_poller(self):
+        wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE)
+        poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
+        assert isinstance(poller, ScheduleTriggerDebugEventPoller)
+
+    def test_raises_for_unknown_type(self):
+        wf = MagicMock()
+        wf.get_node_config_by_id.return_value = {"data": {}}
+        wf.get_node_type_from_node_config.return_value = NodeType.START
+
+        with pytest.raises(ValueError):
+            create_event_poller(wf, "t1", "u1", "a1", "n1")
+
+    def test_raises_when_node_config_missing(self):
+        wf = MagicMock()
+        wf.get_node_config_by_id.return_value = None
+
+        with pytest.raises(ValueError):
+            create_event_poller(wf, "t1", "u1", "a1", "n1")
+
+
+class TestSelectTriggerDebugEvents:
+    def test_returns_first_non_none_event(self):
+        wf = MagicMock()
+        wf.get_node_config_by_id.return_value = {"data": {}}
+        wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
+        app_model = MagicMock()
+        app_model.tenant_id = "t1"
+        app_model.id = "a1"
+
+        with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll:
+            expected = MagicMock()
+            mock_poll.return_value = expected
+
+            result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"])
+
+            assert result is expected
+
+    def test_returns_none_when_no_events(self):
+        wf = MagicMock()
+        wf.get_node_config_by_id.return_value = {"data": {}}
+        wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
+        app_model = MagicMock()
+        app_model.tenant_id = "t1"
+        app_model.id = "a1"
+
+        with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None):
+            result = select_trigger_debug_events(wf, app_model, "u1", ["n1"])
+
+        assert result is None
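
The contract exercised by the two debug-event test files above is a fire-and-forget dispatch plus a non-blocking poll. A minimal sketch of that shape (assuming a plain Redis list per pool key; the real TriggerDebugEventBus runs Lua scripts via redis_client.eval, so the key names and semantics below are invented):

    import json
    import redis

    r = redis.Redis()

    def dispatch(pool_key: str, event_json: str) -> int:
        # Push the serialized event; the list length stands in for the dispatch count.
        return r.rpush(f"debug:{pool_key}", event_json)

    def poll(pool_key: str) -> dict | None:
        # Non-blocking: returns None when no event is waiting, mirroring poll() above.
        raw = r.lpop(f"debug:{pool_key}")
        return json.loads(raw) if raw else None
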
diff --git a/api/tests/unit_tests/core/trigger/test_provider.py b/api/tests/unit_tests/core/trigger/test_provider.py
new file mode 100644
index 0000000000..3c2f297e90
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/test_provider.py
@@ -0,0 +1,332 @@
+"""
+Tests for core.trigger.provider.PluginTriggerProviderController.
+
+Covers: to_api_entity creation-method logic, credential validation pipeline,
+schema resolution by type, event lookup, dispatch/invoke/subscribe delegation.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.plugin.entities.plugin_daemon import CredentialType
+from core.trigger.entities.entities import (
+    EventParameter,
+    EventParameterType,
+    OAuthSchema,
+    TriggerCreationMethod,
+)
+from core.trigger.errors import TriggerProviderCredentialValidationError
+from tests.unit_tests.core.trigger.conftest import (
+    i18n,
+    make_constructor,
+    make_controller,
+    make_event,
+    make_provider_config,
+    make_provider_entity,
+    make_subscription,
+)
+
+ICON_URL = "https://cdn/icon.png"
+
+
+class TestToApiEntity:
+    @patch("core.trigger.provider.PluginService")
+    def test_includes_icons_when_present(self, mock_plugin_svc):
+        mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
+        ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png"))
+
+        api = ctrl.to_api_entity()
+
+        assert api.icon == ICON_URL
+        assert api.icon_dark == ICON_URL
+
+    @patch("core.trigger.provider.PluginService")
+    def test_icons_none_when_absent(self, mock_plugin_svc):
+        ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None))
+
+        api = ctrl.to_api_entity()
+
+        assert api.icon is None
+        assert api.icon_dark is None
+        mock_plugin_svc.get_plugin_icon_url.assert_not_called()
+
+    @patch("core.trigger.provider.PluginService")
+    def test_manual_only_without_schemas(self, mock_plugin_svc):
+        mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
+        ctrl = make_controller(entity=make_provider_entity(constructor=None))
+
+        api = ctrl.to_api_entity()
+
+        assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL]
+
+    @patch("core.trigger.provider.PluginService")
+    def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc):
+        mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
+        oauth = OAuthSchema(client_schema=[], credentials_schema=[])
+        ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
+
+        api = ctrl.to_api_entity()
+
+        assert TriggerCreationMethod.OAUTH in api.supported_creation_methods
+        assert TriggerCreationMethod.MANUAL in api.supported_creation_methods
+
+    @patch("core.trigger.provider.PluginService")
+    def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc):
+        mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
+        ctrl = make_controller(
+            entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
+        )
+
+        api = ctrl.to_api_entity()
+
+        assert TriggerCreationMethod.APIKEY in api.supported_creation_methods
+
+
+class TestGetEvent:
+    def test_returns_matching_event(self):
+        evt = make_event("push")
+        ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")]))
+
+        assert ctrl.get_event("push") is evt
+
+    def test_returns_none_for_unknown(self):
+        ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
+
+        assert ctrl.get_event("nonexistent") is None
+
+
+class TestGetSubscriptionDefaultProperties:
+    def test_returns_defaults_skipping_none(self):
+        config1 = make_provider_config("key1")
+        config1.default = "val1"
+        config2 = make_provider_config("key2")
+        config2.default = None
+        ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2]))
+
+        props = ctrl.get_subscription_default_properties()
+
+        assert props == {"key1": "val1"}
+
+
+class TestValidateCredentials:
+    def test_raises_when_no_constructor(self):
+        ctrl = make_controller(entity=make_provider_entity(constructor=None))
+
+        with pytest.raises(ValueError, match="Subscription constructor not found"):
+            ctrl.validate_credentials("u1", {"key": "val"})
+
+    def test_raises_for_missing_required_field(self):
+        required_cfg = make_provider_config("api_key", required=True)
+        ctrl = make_controller(
+            entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
+        )
+
+        with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"):
+            ctrl.validate_credentials("u1", {})
+
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_passes_with_valid_credentials(self, mock_client):
+        required_cfg = make_provider_config("api_key", required=True)
+        ctrl = make_controller(
+            entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
+        )
+        mock_client.return_value.validate_provider_credentials.return_value = True
+
+        ctrl.validate_credentials("u1", {"api_key": "secret123"})  # should not raise
+
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_raises_when_plugin_rejects(self, mock_client):
+        required_cfg = make_provider_config("api_key", required=True)
+        ctrl = make_controller(
+            entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
+        )
+        mock_client.return_value.validate_provider_credentials.return_value = None
+
+        with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"):
+            ctrl.validate_credentials("u1", {"api_key": "bad"})
+
+
+class TestGetSupportedCredentialTypes:
+    def test_empty_when_no_constructor(self):
+        ctrl = make_controller(entity=make_provider_entity(constructor=None))
+        assert ctrl.get_supported_credential_types() == []
+
+    def test_oauth_only(self):
+        oauth = OAuthSchema(client_schema=[], credentials_schema=[])
+        ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
+
+        types = ctrl.get_supported_credential_types()
+
+        assert CredentialType.OAUTH2 in types
+        assert CredentialType.API_KEY not in types
+
+    def test_apikey_only(self):
+        ctrl = make_controller(
+            entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
+        )
+
+        types = ctrl.get_supported_credential_types()
+
+        assert CredentialType.API_KEY in types
+        assert CredentialType.OAUTH2 not in types
+
+    def test_both(self):
+        oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")])
+        ctrl = make_controller(
+            entity=make_provider_entity(
+                constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth)
+            )
+        )
+
+        types = ctrl.get_supported_credential_types()
+
+        assert CredentialType.OAUTH2 in types
+        assert CredentialType.API_KEY in types
+
+
+class TestGetCredentialsSchema:
+    def test_returns_empty_when_no_constructor(self):
+        ctrl = make_controller(entity=make_provider_entity(constructor=None))
+        assert ctrl.get_credentials_schema(CredentialType.API_KEY) == []
+
+    def test_returns_apikey_credentials(self):
+        cfg = make_provider_config("token")
+        ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg])))
+
+        result = ctrl.get_credentials_schema(CredentialType.API_KEY)
+
+        assert len(result) == 1
+        assert result[0].name == "token"
+
+    def test_returns_oauth_credentials(self):
+        oauth_cred = make_provider_config("oauth_token")
+        oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred])
+        ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
+
+        result = ctrl.get_credentials_schema(CredentialType.OAUTH2)
+
+        assert len(result) == 1
+        assert result[0].name == "oauth_token"
+
+    def test_unauthorized_returns_empty(self):
+        ctrl = make_controller(
+            entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
+        )
+        assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == []
+
+    def test_invalid_type_raises(self):
+        ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor()))
+        with pytest.raises(ValueError, match="Invalid credential type"):
+            ctrl.get_credentials_schema("bogus_type")
+
+
+class TestGetEventParameters:
+    def test_returns_params_for_known_event(self):
+        param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING)
+        evt = make_event("push", parameters=[param])
+        ctrl = make_controller(entity=make_provider_entity(events=[evt]))
+
+        result = ctrl.get_event_parameters("push")
+
+        assert "branch" in result
+        assert result["branch"].name == "branch"
+
+    def test_returns_empty_for_unknown_event(self):
+        ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
+
+        assert ctrl.get_event_parameters("nonexistent") == {}
+
+
+class TestDispatch:
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_delegates_to_client(self, mock_client):
+        ctrl = make_controller()
+        expected = MagicMock()
+        mock_client.return_value.dispatch_event.return_value = expected
+
+        result = ctrl.dispatch(
+            request=MagicMock(),
+            subscription=make_subscription(),
+            credentials={"k": "v"},
+            credential_type=CredentialType.API_KEY,
+        )
+
+        assert result is expected
+        mock_client.return_value.dispatch_event.assert_called_once()
+
+
+class TestInvokeTriggerEvent:
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_delegates_to_client(self, mock_client):
+        ctrl = make_controller()
+        expected = MagicMock()
+        mock_client.return_value.invoke_trigger_event.return_value = expected
+
+        result = ctrl.invoke_trigger_event(
+            user_id="u1",
+            event_name="push",
+            parameters={},
+            credentials={},
+            credential_type=CredentialType.API_KEY,
+            subscription=make_subscription(),
+            request=MagicMock(),
+            payload={},
+        )
+
+        assert result is expected
+
+
+class TestSubscribeTrigger:
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_returns_validated_subscription(self, mock_client):
+        ctrl = make_controller()
+        mock_client.return_value.subscribe.return_value.subscription = {
+            "expires_at": 123,
+            "endpoint": "https://e",
+            "properties": {},
+        }
+
+        result = ctrl.subscribe_trigger(
+            user_id="u1",
+            endpoint="https://e",
+            parameters={},
+            credentials={},
+            credential_type=CredentialType.API_KEY,
+        )
+
+        assert result.endpoint == "https://e"
+
+
+class TestUnsubscribeTrigger:
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_returns_validated_result(self, mock_client):
+        ctrl = make_controller()
+        mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"}
+
+        result = ctrl.unsubscribe_trigger(
+            user_id="u1",
+            subscription=make_subscription(),
+            credentials={},
+            credential_type=CredentialType.API_KEY,
+        )
+
+        assert result.success is True
+
+
+class TestRefreshTrigger:
+    @patch("core.trigger.provider.PluginTriggerClient")
+    def test_uses_system_user_id(self, mock_client):
+        ctrl = make_controller()
+        mock_client.return_value.refresh.return_value.subscription = {
+            "expires_at": 456,
+            "endpoint": "https://e",
+            "properties": {},
+        }
+
+        ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY)
+
+        call_kwargs = mock_client.return_value.refresh.call_args[1]
+        assert call_kwargs["user_id"] == "system"
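
TestValidateCredentials above pins down a two-stage pipeline: a structural required-field check, then a semantic round-trip to the plugin daemon. A self-contained sketch of that control flow (simplified signatures; the real controller delegates stage two to PluginTriggerClient.validate_provider_credentials):

    from dataclasses import dataclass

    class CredentialValidationError(Exception):
        pass

    @dataclass
    class FieldCfg:
        name: str
        required: bool = True

    def validate_credentials(schema: list[FieldCfg], credentials: dict, plugin_ok) -> None:
        # Stage 1: every required field must be present and non-empty.
        for cfg in schema:
            if cfg.required and not credentials.get(cfg.name):
                raise CredentialValidationError(f"Missing required field: {cfg.name}")
        # Stage 2: the plugin gets the final say (stubbed here as a callable).
        if not plugin_ok(credentials):
            raise CredentialValidationError("Invalid credentials")
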
diff --git a/api/tests/unit_tests/core/trigger/test_trigger_manager.py b/api/tests/unit_tests/core/trigger/test_trigger_manager.py
new file mode 100644
index 0000000000..612be25ec9
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/test_trigger_manager.py
@@ -0,0 +1,307 @@
+"""
+Tests for core.trigger.trigger_manager.TriggerManager.
+
+Covers: icon URL construction, provider listing with error resilience,
+double-check lock caching, error translation, EventIgnoreError -> cancelled,
+and delegation to provider controller.
+"""
+
+from __future__ import annotations
+
+from threading import Lock
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.plugin.entities.plugin_daemon import CredentialType
+from core.plugin.entities.request import TriggerInvokeEventResponse
+from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
+from core.trigger.errors import EventIgnoreError
+from core.trigger.trigger_manager import TriggerManager
+from models.provider_ids import TriggerProviderID
+from tests.unit_tests.core.trigger.conftest import (
+    VALID_PROVIDER_ID,
+    make_controller,
+    make_provider_entity,
+    make_subscription,
+)
+
+PID = TriggerProviderID(VALID_PROVIDER_ID)
+PID_STR = str(PID)
+
+
+class TestGetTriggerPluginIcon:
+    @patch("core.trigger.trigger_manager.dify_config")
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    def test_builds_correct_url(self, mock_client, mock_config):
+        mock_config.CONSOLE_API_URL = "https://console.example.com"
+        provider = MagicMock()
+        provider.declaration.identity.icon = "my-icon.svg"
+        mock_client.return_value.fetch_trigger_provider.return_value = provider
+
+        url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID)
+
+        assert "tenant_id=tenant-1" in url
+        assert "filename=my-icon.svg" in url
+        assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon")
+
+
+class TestListPluginTriggerProviders:
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    def test_wraps_entities_into_controllers(self, mock_client):
+        entity = MagicMock()
+        entity.declaration = make_provider_entity("p1")
+        entity.plugin_id = "plugin-1"
+        entity.plugin_unique_identifier = "uid-1"
+        entity.provider = VALID_PROVIDER_ID
+        mock_client.return_value.fetch_trigger_providers.return_value = [entity]
+
+        controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
+
+        assert len(controllers) == 1
+        assert controllers[0].plugin_id == "plugin-1"
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    def test_skips_failing_providers(self, mock_client):
+        good = MagicMock()
+        good.declaration = make_provider_entity("good")
+        good.plugin_id = "good-plugin"
+        good.plugin_unique_identifier = "uid-good"
+        good.provider = VALID_PROVIDER_ID
+
+        bad = MagicMock()
+        bad.declaration = make_provider_entity("bad")
+        bad.plugin_id = "bad-plugin"
+        bad.plugin_unique_identifier = "uid-bad"
+        bad.provider = "bad/format"  # 2-part: fails TriggerProviderID validation
+
+        mock_client.return_value.fetch_trigger_providers.return_value = [bad, good]
+
+        controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
+
+        assert len(controllers) == 1
+        assert controllers[0].plugin_id == "good-plugin"
+
+
+class TestGetTriggerProvider:
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_initializes_context_on_first_call(self, mock_ctx, mock_client):
+        # get() called 3 times: (1) try block, (2) after set, (3) under lock
+        mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}]
+        mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
+        provider = MagicMock()
+        provider.declaration = make_provider_entity()
+        provider.plugin_id = "p1"
+        provider.plugin_unique_identifier = "uid-1"
+        mock_client.return_value.fetch_trigger_provider.return_value = provider
+
+        result = TriggerManager.get_trigger_provider("t1", PID)
+
+        mock_ctx.plugin_trigger_providers.set.assert_called_once_with({})
+        mock_ctx.plugin_trigger_providers_lock.set.assert_called_once()
+        assert result is not None
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_returns_cached_without_fetch(self, mock_ctx, mock_client):
+        cached = make_controller()
+        mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached}
+
+        result = TriggerManager.get_trigger_provider("t1", PID)
+
+        assert result is cached
+        mock_client.return_value.fetch_trigger_provider.assert_not_called()
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client):
+        cached = make_controller()
+        mock_ctx.plugin_trigger_providers.get.side_effect = [
+            {},  # first check misses
+            {PID_STR: cached},  # under-lock check hits
+        ]
+        mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
+
+        result = TriggerManager.get_trigger_provider("t1", PID)
+
+        assert result is cached
+        mock_client.return_value.fetch_trigger_provider.assert_not_called()
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client):
+        cache: dict = {}
+        mock_ctx.plugin_trigger_providers.get.return_value = cache
+        mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
+        provider = MagicMock()
+        provider.declaration = make_provider_entity()
+        provider.plugin_id = "p1"
+        provider.plugin_unique_identifier = "uid-1"
+        mock_client.return_value.fetch_trigger_provider.return_value = provider
+
+        result = TriggerManager.get_trigger_provider("t1", PID)
+
+        assert result is not None
+        assert PID_STR in cache
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_none_fetch_raises_value_error(self, mock_ctx, mock_client):
+        mock_ctx.plugin_trigger_providers.get.return_value = {}
+        mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
+        mock_client.return_value.fetch_trigger_provider.return_value = None
+
+        with pytest.raises(ValueError):
+            TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing"))
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client):
+        mock_ctx.plugin_trigger_providers.get.return_value = {}
+        mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
+        mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone")
+
+        with pytest.raises(ValueError):
+            TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
+
+    @patch("core.trigger.trigger_manager.PluginTriggerClient")
+    @patch("core.trigger.trigger_manager.contexts")
+    def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client):
+        mock_ctx.plugin_trigger_providers.get.return_value = {}
+        mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
+        mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error")
+
+        with pytest.raises(PluginDaemonError):
+            TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
+
+
+class TestListAllTriggerProviders:
+    @patch.object(TriggerManager, "list_plugin_trigger_providers")
+    def test_delegates_to_list_plugin(self, mock_list):
+        expected = [make_controller()]
+        mock_list.return_value = expected
+
+        assert TriggerManager.list_all_trigger_providers("t1") is expected
+        mock_list.assert_called_once_with("t1")
+
+
+class TestListTriggersByProvider:
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_returns_provider_events(self, mock_get):
+        ctrl = make_controller()
+        mock_get.return_value = ctrl
+
+        result = TriggerManager.list_triggers_by_provider("t1", PID)
+
+        assert result == ctrl.get_events()
+
+
+class TestInvokeTriggerEvent:
+    def _args(self):
+        return {
+            "tenant_id": "t1",
+            "user_id": "u1",
+            "provider_id": PID,
+            "event_name": "on_push",
+            "parameters": {"branch": "main"},
+            "credentials": {"token": "abc"},
+            "credential_type": CredentialType.API_KEY,
+            "subscription": make_subscription(),
+            "request": MagicMock(),
+            "payload": {"action": "push"},
+        }
+
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_returns_invoke_response(self, mock_get):
+        ctrl = MagicMock()
+        expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False)
+        ctrl.invoke_trigger_event.return_value = expected
+        mock_get.return_value = ctrl
+
+        result = TriggerManager.invoke_trigger_event(**self._args())
+
+        assert result is expected
+        assert result.cancelled is False
+
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_event_ignore_returns_cancelled(self, mock_get):
+        ctrl = MagicMock()
+        ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip")
+        mock_get.return_value = ctrl
+
+        result = TriggerManager.invoke_trigger_event(**self._args())
+
+        assert result.cancelled is True
+        assert result.variables == {}
+
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_other_errors_propagate(self, mock_get):
+        ctrl = MagicMock()
+        ctrl.invoke_trigger_event.side_effect = RuntimeError("boom")
+        mock_get.return_value = ctrl
+
+        with pytest.raises(RuntimeError, match="boom"):
+            TriggerManager.invoke_trigger_event(**self._args())
+
+
+class TestSubscribeTrigger:
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_delegates_with_correct_args(self, mock_get):
+        ctrl = MagicMock()
+        expected = make_subscription()
+        ctrl.subscribe_trigger.return_value = expected
+        mock_get.return_value = ctrl
+
+        result = TriggerManager.subscribe_trigger(
+            tenant_id="t1",
+            user_id="u1",
+            provider_id=PID,
+            endpoint="https://hook.test",
+            parameters={"f": "all"},
+            credentials={"token": "x"},
+            credential_type=CredentialType.API_KEY,
+        )
+
+        assert result is expected
+        ctrl.subscribe_trigger.assert_called_once()
+
+
+class TestUnsubscribeTrigger:
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_delegates_with_correct_args(self, mock_get):
+        ctrl = MagicMock()
+        expected = MagicMock()
+        ctrl.unsubscribe_trigger.return_value = expected
+        mock_get.return_value = ctrl
+        sub = make_subscription()
+
+        result = TriggerManager.unsubscribe_trigger(
+            tenant_id="t1",
+            user_id="u1",
+            provider_id=PID,
+            subscription=sub,
+            credentials={"token": "x"},
+            credential_type=CredentialType.API_KEY,
+        )
+
+        assert result is expected
+
+
+class TestRefreshTrigger:
+    @patch.object(TriggerManager, "get_trigger_provider")
+    def test_delegates_with_correct_args(self, mock_get):
+        ctrl = MagicMock()
+        expected = make_subscription()
+        ctrl.refresh_trigger.return_value = expected
+        mock_get.return_value = ctrl
+
+        result = TriggerManager.refresh_trigger(
+            tenant_id="t1",
+            provider_id=PID,
+            subscription=make_subscription(),
+            credentials={"token": "x"},
+            credential_type=CredentialType.API_KEY,
+        )
+
+        assert result is expected
diff --git a/api/tests/unit_tests/core/trigger/utils/__init__.py b/api/tests/unit_tests/core/trigger/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py
new file mode 100644
index 0000000000..8804526e2e
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py
@@ -0,0 +1,62 @@
+"""Tests for core.trigger.utils.encryption — masking logic and cache key generation."""
+
+from __future__ import annotations
+
+from core.entities.provider_entities import ProviderConfig
+from core.tools.entities.common_entities import I18nObject
+from core.trigger.utils.encryption import (
+    TriggerProviderCredentialsCache,
+    TriggerProviderOAuthClientParamsCache,
+    TriggerProviderPropertiesCache,
+    masked_credentials,
+)
+
+
+def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig:
+    return ProviderConfig(
+        name=name,
+        label=I18nObject(en_US=name, zh_Hans=name),
+        type=field_type,
+    )
+
+
+class TestMaskedCredentials:
+    def test_short_secret_fully_masked(self):
+        schema = [_make_schema("key", "secret-input")]
+        result = masked_credentials(schema, {"key": "ab"})
+        assert result["key"] == "**"
+
+    def test_long_secret_partially_masked(self):
+        schema = [_make_schema("key", "secret-input")]
+        result = masked_credentials(schema, {"key": "abcdef"})
+        assert result["key"].startswith("ab")
+        assert result["key"].endswith("ef")
+        assert "**" in result["key"]
+
+    def test_non_secret_field_unchanged(self):
+        schema = [_make_schema("host", "text-input")]
+        result = masked_credentials(schema, {"host": "example.com"})
+        assert result["host"] == "example.com"
+
+    def test_unknown_key_passes_through(self):
+        result = masked_credentials([], {"unknown": "value"})
+        assert result["unknown"] == "value"
+
+
+class TestCacheKeyGeneration:
+    def test_credentials_cache_key_contains_ids(self):
+        cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1")
+        assert "t1" in cache.cache_key
+        assert "p1" in cache.cache_key
+        assert "c1" in cache.cache_key
+
+    def test_oauth_client_cache_key_contains_ids(self):
+        cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1")
+        assert "t1" in cache.cache_key
+        assert "p1" in cache.cache_key
+
+    def test_properties_cache_key_contains_ids(self):
+        cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1")
+        assert "t1" in cache.cache_key
+        assert "p1" in cache.cache_key
+        assert "s1" in cache.cache_key
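
The masking behaviour pinned by TestMaskedCredentials above (full mask for short secrets, a two-character prefix and suffix kept for longer ones) reduces to a rule like the following sketch; the exact length threshold is an assumption, the real one lives in core.trigger.utils.encryption:

    def mask_secret(value: str) -> str:
        # Short secrets leak too much even partially revealed: mask everything.
        if len(value) <= 4:  # assumed cut-off
            return "*" * len(value)
        # Longer secrets keep a recognizable 2-char prefix and suffix.
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
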
diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py
new file mode 100644
index 0000000000..e5879aea0a
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py
@@ -0,0 +1,31 @@
+"""Tests for core.trigger.utils.endpoint — URL generation."""
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from yarl import URL
+
+from core.trigger.utils import endpoint
+
+
+class TestGeneratePluginTriggerEndpointUrl:
+    def test_builds_correct_url(self):
+        with patch.object(endpoint, "base_url", URL("https://api.example.com")):
+            url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123")
+
+        assert url == "https://api.example.com/triggers/plugin/endpoint-123"
+
+
+class TestGenerateWebhookTriggerEndpoint:
+    def test_non_debug_url(self):
+        with patch.object(endpoint, "base_url", URL("https://api.example.com")):
+            url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False)
+
+        assert url == "https://api.example.com/triggers/webhook/sub-456"
+
+    def test_debug_url(self):
+        with patch.object(endpoint, "base_url", URL("https://api.example.com")):
+            url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True)
+
+        assert url == "https://api.example.com/triggers/webhook-debug/sub-456"
diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py
new file mode 100644
index 0000000000..4fa202b164
--- /dev/null
+++ b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py
@@ -0,0 +1,23 @@
+"""Tests for core.trigger.utils.locks — Redis lock key builders."""
+
+from __future__ import annotations
+
+from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys
+
+
+class TestBuildTriggerRefreshLockKey:
+    def test_correct_format(self):
+        key = build_trigger_refresh_lock_key("tenant-1", "sub-1")
+
+        assert key == "trigger_provider_refresh_lock:tenant-1_sub-1"
+
+
+class TestBuildTriggerRefreshLockKeys:
+    def test_maps_over_pairs(self):
+        pairs = [("t1", "s1"), ("t2", "s2")]
+
+        keys = build_trigger_refresh_lock_keys(pairs)
+
+        assert len(keys) == 2
+        assert keys[0] == "trigger_provider_refresh_lock:t1_s1"
+        assert keys[1] == "trigger_provider_refresh_lock:t2_s2"
diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py
index 0df4927697..22792eb5b3 100644
--- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py
+++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py
@@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
+from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
 from dify_graph.model_runtime.entities.llm_entities import LLMUsage
 from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
+from dify_graph.variables.variables import StringVariable
 
 
 class StubCoordinator:
@@ -278,3 +280,17 @@ class TestGraphRuntimeState:
 
         assert restored_execution.started is True
         assert new_stub.state == "configured"
+
+    def test_snapshot_restore_preserves_updated_conversation_variable(self):
+        variable_pool = VariablePool(
+            conversation_variables=[StringVariable(name="session_name", value="before")],
+        )
+        variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after")
+
+        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
+        snapshot = state.dumps()
+        restored = GraphRuntimeState.from_snapshot(snapshot)
+
+        restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name"))
+        assert restored_value is not None
+        assert restored_value.value == "after"
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py
index eb449e6d75..8c8e5977c8 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py
@@ -2,15 +2,7 @@
 Simple test to verify MockNodeFactory works with iteration nodes.
 """
 
-import sys
-from pathlib import Path
-
 from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
-
-# Add api directory to path
-api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
-sys.path.insert(0, str(api_dir))
-
 from dify_graph.enums import NodeType
 from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder
 from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 3f458e9de9..34e714a227 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
 from dify_graph.nodes.llm import LLMNode
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
+from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
 from dify_graph.nodes.question_classifier import QuestionClassifierNode
 from dify_graph.nodes.template_transform import TemplateTransformNode
 from dify_graph.nodes.template_transform.template_renderer import (
@@ -65,11 +66,19 @@ class MockNodeMixin:
             kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
             kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
             kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
+            # LLM-like nodes now require an http_client; provide a mock by default for tests.
+            kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
 
         # Ensure TemplateTransformNode receives a renderer now required by constructor
         if isinstance(self, TemplateTransformNode):
             kwargs.setdefault("template_renderer", _TestJinja2Renderer())
 
+        # Provide default tool_file_manager_factory for ToolNode subclasses
+        from dify_graph.nodes.tool import ToolNode as _ToolNode  # local import to avoid cycles
+
+        if isinstance(self, _ToolNode):
+            kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol))
+
         super().__init__(
             id=id,
             config=config,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
index 84d1444585..693cdf9276 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
@@ -3,14 +3,8 @@ Simple test to validate the auto-mock system without external dependencies.
 """
 
 import sys
-from pathlib import Path
 
 from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
-
-# Add api directory to path
-api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
-sys.path.insert(0, str(api_dir))
-
 from dify_graph.enums import NodeType
 from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
 from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index e194d66ee3..e929d652fd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -8,7 +8,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.model_runtime.entities.llm_entities import LLMUsage
 from dify_graph.nodes.knowledge_retrieval.entities import (
+    Condition,
     KnowledgeRetrievalNodeData,
+    MetadataFilteringCondition,
     MultipleRetrievalConfig,
     RerankingModelConfig,
     SingleRetrievalConfig,
@@ -110,7 +112,6 @@ class TestKnowledgeRetrievalNode:
         # Assert
         assert node.id == node_id
         assert node._rag_retrieval == mock_rag_retrieval
-        assert node._llm_file_saver is not None
 
     def test_run_with_no_query_or_attachment(
         self,
@@ -205,6 +206,7 @@ class TestKnowledgeRetrievalNode:
         assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
         assert "result" in result.outputs
         assert mock_rag_retrieval.knowledge_retrieval.called
+        mock_source.model_dump.assert_called_once_with(by_alias=True)
 
     def test_run_with_query_variable_multiple_mode(
         self,
@@ -592,3 +594,106 @@ class TestFetchDatasetRetriever:
 
         # Assert
         assert version == "1"
+
+    def test_resolve_metadata_filtering_conditions_templates(
+        self,
+        mock_graph_init_params,
+        mock_graph_runtime_state,
+        mock_rag_retrieval,
+    ):
+        """_resolve_metadata_filtering_conditions should expand {{#...#}} and keep numbers/None unchanged."""
+        # Arrange
+        node_id = str(uuid.uuid4())
+        config = {
+            "id": node_id,
+            "data": {
+                "title": "Knowledge Retrieval",
+                "type": "knowledge-retrieval",
+                "dataset_ids": [str(uuid.uuid4())],
+                "retrieval_mode": "multiple",
+            },
+        }
+        # Variable in pool used by template
+        mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
+
+        node = KnowledgeRetrievalNode(
+            id=node_id,
+            config=config,
+            graph_init_params=mock_graph_init_params,
+            graph_runtime_state=mock_graph_runtime_state,
+            rag_retrieval=mock_rag_retrieval,
+        )
+
+        conditions = MetadataFilteringCondition(
+            logical_operator="and",
+            conditions=[
+                Condition(name="document_name", comparison_operator="is", value="{{#start.query#}}"),
+                Condition(name="tags", comparison_operator="in", value=["x", "{{#start.query#}}"]),
+                Condition(name="year", comparison_operator="=", value=2025),
+            ],
+        )
+
+        # Act
+        resolved = node._resolve_metadata_filtering_conditions(conditions)
+
+        # Assert
+        assert resolved.logical_operator == "and"
+        assert resolved.conditions[0].value == "readme"
+        assert isinstance(resolved.conditions[1].value, list)
+        assert resolved.conditions[1].value[1] == "readme"
+        assert resolved.conditions[2].value == 2025
+
+    def test_fetch_passes_resolved_metadata_conditions(
+        self,
+        mock_graph_init_params,
+        mock_graph_runtime_state,
+        mock_rag_retrieval,
+    ):
+        """_fetch_dataset_retriever should pass resolved metadata conditions into request."""
+        # Arrange
+        query = "hi"
+        variables = {"query": query}
+        mock_graph_runtime_state.variable_pool.add(["start", "q"], StringSegment(value="readme"))
+
+        node_data = KnowledgeRetrievalNodeData(
+            title="Knowledge Retrieval",
+            type="knowledge-retrieval",
+            dataset_ids=[str(uuid.uuid4())],
+            retrieval_mode="multiple",
+            multiple_retrieval_config=MultipleRetrievalConfig(
+                top_k=4,
+                score_threshold=0.0,
+                reranking_mode="reranking_model",
+                reranking_enable=True,
+                reranking_model=RerankingModelConfig(provider="cohere", model="rerank-v2"),
+            ),
+            metadata_filtering_mode="manual",
+            metadata_filtering_conditions=MetadataFilteringCondition(
+                logical_operator="and",
+                conditions=[
+                    Condition(name="document_name", comparison_operator="is", value="{{#start.q#}}"),
+                ],
+            ),
+        )
+
+        node_id = str(uuid.uuid4())
+        config = {"id": node_id, "data": node_data.model_dump()}
+        node = KnowledgeRetrievalNode(
+            id=node_id,
+            config=config,
+            graph_init_params=mock_graph_init_params,
+            graph_runtime_state=mock_graph_runtime_state,
+            rag_retrieval=mock_rag_retrieval,
+        )
+
+        mock_rag_retrieval.knowledge_retrieval.return_value = []
+        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
+
+        # Act
+        node._fetch_dataset_retriever(node_data=node_data, variables=variables)
+
+        # Assert the passed request has resolved value
+        call_args = mock_rag_retrieval.knowledge_retrieval.call_args
+        request = call_args[1]["request"]
+        assert request.metadata_filtering_conditions is not None
+        assert request.metadata_filtering_conditions.conditions[0].value == "readme"
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
index a3afd1ed5c..b0f0fd428b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
@@ -1,10 +1,10 @@
 import uuid
 from typing import NamedTuple
 from unittest import mock
+from unittest.mock import MagicMock
 
 import httpx
 import pytest
-from sqlalchemy import Engine
 
 from core.helper import ssrf_proxy
 from core.tools import signature
@@ -44,7 +44,6 @@ class TestFileSaverImpl:
         )
         mock_tool_file.id = _gen_id()
         mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
-        mocked_engine = mock.MagicMock(spec=Engine)
         mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
 
         monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
@@ -53,11 +52,12 @@ class TestFileSaverImpl:
         # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
         monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
         mocked_sign_file.return_value = mock_signed_url
+        http_client = MagicMock()
 
         storage_file_manager = FileSaverImpl(
             user_id=user_id,
             tenant_id=tenant_id,
-            engine_factory=mocked_engine,
+            http_client=http_client,
         )
 
         file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
@@ -87,16 +87,18 @@ class TestFileSaverImpl:
             status_code=401,
             request=mock_request,
         )
+        http_client = MagicMock()
+        http_client.get.return_value = mock_response
+
         file_saver = FileSaverImpl(
             user_id=_gen_id(),
             tenant_id=_gen_id(),
+            http_client=http_client,
         )
-        mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
-        monkeypatch.setattr(ssrf_proxy, "get", mock_get)
 
         with pytest.raises(httpx.HTTPStatusError) as exc:
             file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
-        mock_get.assert_called_once_with(_TEST_URL)
+        http_client.get.assert_called_once_with(_TEST_URL)
         assert exc.value.response.status_code == 401
 
     def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
@@ -112,8 +114,10 @@ class TestFileSaverImpl:
             headers={"Content-Type": mime_type},
             request=mock_request,
         )
+        http_client = MagicMock()
+        http_client.get.return_value = mock_response
 
-        file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
+        file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
 
         mock_tool_file = ToolFile(
             user_id=user_id,
             tenant_id=tenant_id,
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 90308facc3..d56035b6bc 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -111,6 +111,7 @@ def llm_node(
         "id": "1",
         "data": llm_node_data.model_dump(),
     }
+    http_client = mock.MagicMock()
     node = LLMNode(
         id="1",
         config=node_config,
@@ -120,6 +121,7 @@ def llm_node(
         model_factory=mock_model_factory,
         model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
+        http_client=http_client,
     )
     return node
 
@@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state):
         "id": "1",
         "data": llm_node_data.model_dump(),
     }
+    http_client = mock.MagicMock()
     node = LLMNode(
         id="1",
         config=node_config,
@@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state):
         model_factory=mock_model_factory,
         model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
+        http_client=http_client,
     )
     return node, mock_file_saver
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 5e20b1e12f..13275d4be6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -16,6 +16,7 @@ from dify_graph.nodes.document_extractor.node import (
     _extract_text_from_excel,
     _extract_text_from_pdf,
     _extract_text_from_plain_text,
+    _normalize_docx_zip,
 )
 from dify_graph.variables import ArrayFileSegment
 from dify_graph.variables.segments import ArrayStringSegment
@@ -86,6 +87,38 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state):
     assert "is not an ArrayFileSegment" in result.error
 
 
+def test_run_empty_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state):
+    """Empty file list should return SUCCEEDED with empty documents and ArrayStringSegment([])."""
+    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+
+    # Provide an actual ArrayFileSegment with an empty list
+    mock_graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=[])
+
+    result = document_extractor_node._run()
+
+    assert isinstance(result, NodeRunResult)
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
+    assert result.process_data.get("documents") == []
+    assert result.outputs["text"] == ArrayStringSegment(value=[])
+
+
+def test_run_none_only_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state):
+    """A file list containing only None (e.g., [None]) should be filtered to [] and succeed."""
+    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+
+    # Use a Mock to bypass type validation for None entries in the list
+    afs = Mock(spec=ArrayFileSegment)
+    afs.value = [None]
+    mock_graph_runtime_state.variable_pool.get.return_value = afs
+
+    result = document_extractor_node._run()
+
+    assert isinstance(result, NodeRunResult)
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
+    assert result.process_data.get("documents") == []
+    assert result.outputs["text"] == ArrayStringSegment(value=[])
+
+
 @pytest.mark.parametrize(
     ("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
     [
@@ -385,3 +418,58 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
     expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n"
 
     assert expected_manual == result
+
+
+def _make_docx_zip(use_backslash: bool) -> bytes:
+    """Helper to build a minimal in-memory DOCX zip.
+
+    When use_backslash=True the ZIP entry names use backslash separators
+    (as produced by Evernote on Windows), otherwise forward slashes are used.
+    """
+    import zipfile
+
+    sep = "\\" if use_backslash else "/"
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr("[Content_Types].xml", b"
diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx
index 8ea29e8e45..07b685b8c5 100644
--- a/web/app/account/(commonLayout)/avatar.tsx
+++ b/web/app/account/(commonLayout)/avatar.tsx
@@ -7,12 +7,11 @@ import { useRouter } from 'next/navigation'
import { Fragment } from 'react'
import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils'
-import Avatar from '@/app/components/base/avatar'
+import { Avatar } from '@/app/components/base/avatar'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge'
-import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
-import { useLogout } from '@/service/use-common'
+import { useLogout, useUserProfile } from '@/service/use-common'
export type IAppSelector = {
isMobile: boolean
@@ -21,10 +20,15 @@ export type IAppSelector = {
export default function AppSelector() {
const router = useRouter()
const { t } = useTranslation()
- const { userProfile } = useAppContext()
+ const { data: userProfileResp } = useUserProfile()
+ const userProfile = userProfileResp?.profile
const { isEducationAccount } = useProviderContext()
const { mutateAsync: logout } = useLogout()
+
+ if (!userProfile)
+ return null
+
const handleLogout = async () => {
await logout()
@@ -50,7 +54,7 @@ export default function AppSelector() {
${open && 'bg-components-panel-bg-blur'}
`}
>
diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx
index 8ca817c872..2c0e4b2694 100644
--- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx
+++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx
@@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'
 import { AccessMode } from '@/models/access-control'
 import { useAppWhiteListSubjects } from '@/service/access-control'
 import useAccessControlStore from '../../../../context/access-control-store'
-import Avatar from '../../base/avatar'
+import { Avatar } from '../../base/avatar'
 import Loading from '../../base/loading'
 import Tooltip from '../../base/tooltip'
 import AddMemberOrGroupDialog from './add-member-or-group-pop'
@@ -106,7 +106,7 @@ function MemberItem({ member }: MemberItemProps) {
   }, [member, setSpecificMembers, specificMembers])
   return (
diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx
index 0bbed83a99..09a5ff6d07 100644
--- a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx
+++ b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx
@@ -172,12 +172,8 @@ describe('dataset-config/card-item', () => {
     const [editButton] = within(card).getAllByRole('button', { hidden: true })
     await user.click(editButton)

-    expect(screen.getByText('Mock settings modal')).toBeInTheDocument()
-    await waitFor(() => {
-      expect(screen.getByRole('dialog')).toBeVisible()
-    })
-
-    fireEvent.click(screen.getByText('Save changes'))
+    expect(await screen.findByText('Mock settings modal')).toBeInTheDocument()
+    fireEvent.click(await screen.findByText('Save changes'))

     await waitFor(() => {
       expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' }))
@@ -194,7 +190,7 @@ describe('dataset-config/card-item', () => {
     const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
     const buttons = within(card).getAllByRole('button', { hidden: true })
-    const deleteButton = buttons[buttons.length - 1]
+    const deleteButton = buttons.at(-1)!

     expect(deleteButton.className).not.toContain('action-btn-destructive')

@@ -233,7 +229,7 @@ describe('dataset-config/card-item', () => {
     await user.click(editButton)

     expect(screen.getByText('Mock settings modal')).toBeInTheDocument()
-    const overlay = Array.from(document.querySelectorAll('[class]'))
+    const overlay = [...document.querySelectorAll('[class]')]
       .find(element => element.className.toString().includes('bg-black/30'))

     expect(overlay).toBeInTheDocument()
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx
index d621bb3941..350ede8c96 100644
--- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.spec.tsx
@@ -91,7 +91,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
 }))

 vi.mock('@/app/components/base/avatar', () => ({
-  default: ({ name }: { name: string }) =>