Compare commits

...

84 Commits

Author SHA1 Message Date
Asuka Minato
41c4d0bff0 Merge branch 'main' into copilot/fix-ce7a9189-f70c-4ea9-81ba-3f49303104d6 2025-10-16 20:50:19 +09:00
Xiyuan Chen
06649f6c21 Update email templates to improve clarity and consistency in messagin… (#26970) 2025-10-16 01:42:22 -07:00
Yongtao Huang
8b61f5e9c4 Fix: avoid duplicate response_chunk update in convert_stream_simple_response (#26965)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-16 15:53:07 +08:00
GuanMu
6432898e7a refactor: update TypeScript definitions for custom JSX elements and clean up global declarations in emoji picker (#26985) 2025-10-16 15:51:39 +08:00
Asuka Minato
cced33d068 use deco to avoid current_user (#26077)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-16 15:45:51 +09:00
Xiyuan Chen
bd01af6415 fix: update load balancing configurations with new credential IDs and… (#26900) 2025-10-15 21:15:26 -07:00
wellCh4n
35011b810d feat: run with params from logs (#26787)
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
2025-10-16 11:01:11 +08:00
Xin Zhang
f295c7532c fix plugin installation permissions when using a local pkg (#26822)
Co-authored-by: zhangx1n <zhangxin@dify.ai>
2025-10-16 10:58:28 +08:00
zyssyz123
7065b67d07 add app mode for message (#26876) 2025-10-16 10:19:49 +08:00
GuanMu
c0b50ef61d chore: remove unused icon components and related features from the co… (#26933) 2025-10-15 16:48:02 +08:00
-LAN-
1d8cca4fa2 Fix: check external commands after node completion (#26891)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-15 16:47:43 +08:00
Wu Tianwei
3474c179e6 fix: enhance dataset menu and add service API translations (#26931) 2025-10-15 16:46:46 +08:00
GuanMu
433dad7e1a chore: add type-check script to package.json for TypeScript validation (#26929) 2025-10-15 16:37:46 +08:00
github-actions[bot]
be7ee380bc chore: translate i18n files and update type definitions (#26916)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-15 16:36:39 +08:00
yangzheli
cff5de626b feat(agent): similar to the start node of workflow, agent variables also support drag-and-drop (#26899) 2025-10-15 13:07:51 +08:00
znn
4d8b8f9210 allow editing of hidden inputs in preview (#24370)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2025-10-15 11:19:53 +08:00
Ademílson Tonato
a16ef7e73c refactor: Update Firecrawl to use v2 API (#24734)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-15 10:48:54 +08:00
kenwoodjw
c39dae06d4 fix: workflow token usage (#26723)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-10-15 10:39:51 +08:00
Novice
f906e70f6b chore: remove redundant dependencies (#26907) 2025-10-15 09:55:39 +08:00
lyzno1
5139119307 chore: bump pnpm version (#26905)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
2025-10-15 09:55:05 +08:00
Yusuke Yamada
1b537f904a fix: replace CodeGroup's POST /meta with GET /site (#26886) 2025-10-15 09:43:10 +08:00
-LAN-
556b631c54 Normalize null metadata handling in tool entities (#26890) 2025-10-15 09:42:22 +08:00
NeatGuyCoding
49df9ceaf3 minor fix: test cases for alibabacloud mysql and chinese translations (#26902)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-15 09:41:12 +08:00
Tonlo
92ec1ac27a Fix/remove logo in withoutbrand template (#26882) 2025-10-15 09:40:33 +08:00
-LAN-
e74097afdf Remove unused after_request hooks from console API keys (#26896) 2025-10-15 00:43:11 +08:00
Asuka Minato
8ddc4f2292 example to auto rollback (#26200)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-15 00:42:55 +09:00
非法操作
7b51320346 fix: when create provider credential set the provider record to vaild (#26868)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 19:42:48 +08:00
GuanMu
9e39be0770 fix: correct indentation in JSON payloads (#26871) 2025-10-14 19:41:01 +08:00
GuanMu
3e5e87930c feat: add Knip configuration for dead code detection and remove unused icon components (#26758)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 19:06:31 +08:00
hj24
15a5ba67f1 fix: use account id in workflow app log filter (#26811)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 14:32:40 +08:00
github-actions[bot]
9e3b4dc90d chore: translate i18n files and update type definitions (#26859)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
2025-10-14 10:43:28 +08:00
Dhruv Gorasiya
48c42a9fba Weaviate update version (#25447)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-14 10:39:53 +08:00
XlKsyt
0b35bc1ede feat: add Tencent Cloud APM tracing integration (#25657)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:21:17 +08:00
Davide Delbianco
8e01bb40fe fix: Do not show the toggle button for chat input when all input hidden (#26826)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:15:06 +08:00
Guangdong Liu
9d21772820 fix: Validate transfer method in file mapping and improve file input handling (#26848)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:10:31 +08:00
NeatGuyCoding
b745839bdb Feature add test containers mail owner transfer task (#26854)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:01:47 +08:00
Eric Guo
59ad6e02ce Add timeout so any plugin daemon call (including the SSE path) that legitimately takes longer than 5s would right. (#26852) 2025-10-14 09:23:27 +08:00
Guangdong Liu
a3b33cbe28 refactor: streamline database session usage in batch_create_segment_to_index_task (#26795)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-14 09:22:48 +08:00
Davide Delbianco
7b8540281a fix: Chat Opener visibility flickering (#26836) 2025-10-14 09:21:00 +08:00
Asuka Minato
0a6b78f883 Use hook to get userid (#26839)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-14 09:20:37 +08:00
heyszt
56ee8f7d64 fix: files/support-type JSON serialization error (#26842) 2025-10-14 09:20:19 +08:00
Davide Delbianco
3cfcd32876 chore: Fix 25795 (#26823)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-13 17:44:51 +08:00
Davide Delbianco
06dcb55a9d chore: Don't show chat input area scrollbar overflow (#26828) 2025-10-13 17:43:46 +08:00
Ponder
ec6cafd7aa feat: Cache AppQueueManager.is_stopped() to reduce unnecessary Redis … (#26778) 2025-10-13 17:41:16 +08:00
Davide Delbianco
6e9858960d chore: Fix chat-input-area resize (#26824) 2025-10-13 17:36:15 +08:00
minglu7
150a8276b9 fix: avoid closing shared session during embeddings (#26830) 2025-10-13 17:36:00 +08:00
Davide Delbianco
c6a90d4bb3 fix: Don't hide chat streaming loader on '\n' content (#26829) 2025-10-13 17:31:52 +08:00
Davide Delbianco
c71fd7113c chore: Correct padding in embedded chatbot (#26832) 2025-10-13 17:29:47 +08:00
NFish
5fc104a992 Fix/web app permission check (#26821) 2025-10-13 16:17:42 +08:00
fenglin
d1de3cfb94 fix: use enum .value strings in retrieval-setting API to fix JSON serialization error (#26785)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-13 13:01:44 +08:00
屈定
44d36f2460 fix: external knowledge url check ssrf (#26789)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-13 11:19:00 +08:00
Wu Tianwei
9088f151d9 fix: invalid data source list in plugin refresh hook (#26813) 2025-10-13 11:17:46 +08:00
Wu Tianwei
c692962650 fix: update tooltip for chunk structure in knowledge base component (#26808) 2025-10-13 10:44:10 +08:00
Wu Tianwei
f0a60a9000 feat: enhance DataSources component with marketplace plugin integration and search filtering (#26810)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-13 10:43:51 +08:00
AsperforMias
2f50f3fd4b refactor: use libs.login current_user in console controllers (#26745)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-13 10:33:33 +08:00
Asuka Minato
24cd7bbc62 fix RetrievalMethod StrEnum (#26768)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-13 10:29:37 +08:00
Guangdong Liu
d299e75e1b refactor: use dynamic max characters for chunking in extractors (#26782) 2025-10-13 10:22:59 +08:00
yangzheli
f86b6658c9 perf(web): split constant files to improve web performance (#26794) 2025-10-13 10:22:34 +08:00
Asuka Minato
0a56d65581 Issue 23579 (#26777)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-13 10:16:12 +08:00
Yuto Yamada
dfc03bac9f Fix typo: reponse to response (#26792) 2025-10-13 10:04:19 +08:00
dependabot[bot]
81e1376e08 chore(deps): bump opik from 1.7.43 to 1.8.72 in /api (#26804)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 10:00:35 +08:00
dependabot[bot]
f50c85d536 chore(deps-dev): bump knip from 5.64.1 to 5.64.3 in /web (#26802)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 10:00:03 +08:00
dependabot[bot]
5830c69694 chore(deps): bump @lexical/utils from 0.36.2 to 0.37.0 in /web (#26801)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 09:59:31 +08:00
crazywoola
0173496a77 fix: happy-dom version (#26764)
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
2025-10-11 18:59:31 +08:00
lyzno1
30c5b47699 refactor: simplify InlineDeleteConfirm component structure (#26771) 2025-10-11 18:18:18 +08:00
NeatGuyCoding
e3191d4e91 fix enum and type (#26756)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-11 17:46:44 +08:00
lyzno1
a9b3539b90 feat: migrate Python SDK to httpx with async/await support (#26726)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-11 17:45:42 +08:00
github-actions[bot]
5217017e69 chore: translate i18n files and update type definitions (#26763)
Co-authored-by: asukaminato0721 <30024051+asukaminato0721@users.noreply.github.com>
2025-10-11 17:23:40 +08:00
lyzno1
bd5df5cf1c feat: add InlineDeleteConfirm base component (#26762) 2025-10-11 17:33:31 +09:00
Guangdong Liu
456dbfe7d7 feat: add tracking for updated_by and updated_at fields in app models (#26736) 2025-10-11 13:48:57 +08:00
GuanMu
586f210d6e chore: remove unused dependencies for dagre from package.json and pnpm-lock.yaml (#26755) 2025-10-11 13:01:05 +08:00
Maries
275a0f9ddd chore(workflows): update deployment configurations for trigger dev (#26753) 2025-10-11 12:43:09 +08:00
carribean
cbf2ba6cec Feature integrate alibabacloud mysql vector (#25994)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-11 10:47:28 +08:00
Asuka Minato
1bd621f819 remove .value (#26633)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-11 09:08:29 +08:00
Asuka Minato
bb6a331490 change all to httpx (#26119)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-10 23:41:16 +08:00
Asuka Minato
3922ad876f part of add type to orm (#26262)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-10 23:40:54 +08:00
jiangbo721
fdb53fdeb1 fix: Set ApiTool’s do_http_request to do not retry. (#26721) 2025-10-10 23:39:25 +08:00
GuanMu
3fb5a7bff1 fix: add z-index class to PortalToFollowElemContent for proper layering in dataset extra info component (#26729) 2025-10-10 23:39:13 +08:00
heyszt
6157c67cfe fix: sync aliyun icon SVG files (#26719) 2025-10-10 23:38:45 +08:00
GuanMu
fbc745764a chore: update packageManager version in package.json to pnpm@10.18.2 (#26731) 2025-10-10 23:37:40 +08:00
Arno Ren
78f09801b5 fix: #26668 restore manual tool parameter values (#26733)
Co-authored-by: renzeyu1 <renzeyu1@lixiang.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-10 23:37:10 +08:00
copilot-swe-agent[bot]
c8002460ac Complete Resource to MethodView migration feasibility study
Co-authored-by: asukaminato0721 <30024051+asukaminato0721@users.noreply.github.com>
2025-09-25 16:45:40 +00:00
copilot-swe-agent[bot]
5a9045cd63 Proof-of-concept: Resource to MethodView conversion research
Co-authored-by: asukaminato0721 <30024051+asukaminato0721@users.noreply.github.com>
2025-09-25 16:41:51 +00:00
copilot-swe-agent[bot]
659ff5c69e Initial plan 2025-09-25 16:26:25 +00:00
580 changed files with 12281 additions and 6378 deletions

View File

@@ -39,25 +39,11 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@@ -93,3 +79,19 @@ jobs:
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@@ -30,6 +30,8 @@ jobs:
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union

View File

@@ -4,8 +4,7 @@ on:
push:
branches:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "deploy/**"
- "build/**"
- "release/e-*"
- "hotfix/**"

View File

@@ -1,4 +1,4 @@
name: Deploy RAG Dev
name: Deploy Trigger Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/trigger-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -1,6 +1,7 @@
#!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml

View File

@@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600

View File

@@ -1521,6 +1521,14 @@ def transform_datasource_credentials():
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
if not firecrawl_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
@@ -1576,6 +1584,14 @@ def transform_datasource_credentials():
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
if not jina_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")

View File

@@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key",
)
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(

View File

@@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
@@ -330,6 +331,7 @@ class MiddlewareConfig(
ClickzettaConfig,
HuaweiCloudConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,

View File

@@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AlibabaCloudMySQLConfig(BaseSettings):
"""
Configuration settings for AlibabaCloud MySQL vector database
"""
ALIBABACLOUD_MYSQL_HOST: str = Field(
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
default="localhost",
)
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
default=3306,
)
ALIBABACLOUD_MYSQL_USER: str = Field(
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
default="root",
)
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
default="",
)
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
default="dify",
)
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the connection pool",
default=5,
)
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
default="utf8mb4",
)
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
"(e.g., 'cosine', 'euclidean')",
default="cosine",
)
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
default=6,
)

View File

@@ -1,23 +1,24 @@
from enum import Enum
from enum import StrEnum
from typing import Literal
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
"""
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,

View File

@@ -1,5 +1,4 @@
import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
@@ -8,12 +7,12 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from . import api, console_ns
from .wraps import account_initialization_required, setup_required
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
@@ -57,7 +56,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@@ -66,13 +67,12 @@ class BaseApiKeyListResource(Resource):
return {"items": keys}
@marshal_with(api_key_fields)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@@ -89,7 +89,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@@ -108,7 +108,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
@@ -152,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -173,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -202,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
@@ -223,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@@ -1,15 +1,14 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_redis import redis_client
@@ -42,10 +41,8 @@ class AnnotationReplyActionApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
@@ -69,10 +66,8 @@ class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
@@ -98,10 +93,8 @@ class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
@@ -124,10 +117,8 @@ class AnnotationReplyActionStatusApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
@@ -159,10 +150,8 @@ class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
@@ -198,10 +187,8 @@ class AnnotationApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
@@ -213,10 +200,8 @@ class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly
@@ -249,10 +234,8 @@ class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
@@ -271,11 +254,9 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
@@ -288,10 +269,8 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@@ -310,10 +289,8 @@ class AnnotationBatchImportApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# check file
if "file" not in request.files:
@@ -341,10 +318,8 @@ class AnnotationBatchImportStatusApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key)
@@ -376,10 +351,8 @@ class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)

View File

@@ -1,7 +1,5 @@
import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import Account, App
from models import App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -56,6 +55,7 @@ class AppListApi(Resource):
@enterprise_license_required
def get(self):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value):
try:
@@ -90,7 +90,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -129,8 +129,10 @@ class AppListApi(Resource):
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
@@ -140,19 +142,11 @@ class AppListApi(Resource):
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
app = app_service.create_app(current_tenant_id, args, current_user)
return app, 201
@@ -205,13 +199,10 @@ class AppApi(Resource):
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site)
def put(self, app_model):
"""Update app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
@@ -248,12 +239,9 @@ class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_model):
"""Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
app_service.delete_app(app_model)
@@ -283,12 +271,12 @@ class AppCopyApi(Resource):
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
@@ -301,10 +289,9 @@ class AppCopyApi(Resource):
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
@@ -340,12 +327,9 @@ class AppExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
@@ -371,11 +355,8 @@ class AppNameApi(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
@@ -408,11 +389,8 @@ class AppIconApi(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@@ -441,11 +419,8 @@ class AppSiteStatus(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
@@ -475,6 +450,7 @@ class AppApiStatus(Resource):
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -520,10 +496,9 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id):
# add app trace
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")

View File

@@ -1,20 +1,16 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
@@ -30,11 +26,10 @@ class AppImportApi(Resource):
@account_initialization_required
@marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
@@ -51,7 +46,7 @@ class AppImportApi(Resource):
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
@@ -70,9 +65,9 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@@ -83,21 +78,21 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@edit_permission_required
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@@ -109,10 +104,8 @@ class AppImportCheckDependenciesApi(Resource):
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
@edit_permission_required
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api, console_ns
@@ -15,7 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -151,13 +151,8 @@ class ChatMessageApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")

View File

@@ -1,17 +1,16 @@
from datetime import datetime
import pytz # pip install pytz
import pytz
import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@@ -22,8 +21,8 @@ from fields.conversation_fields import (
)
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@@ -57,9 +56,9 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -84,6 +83,7 @@ class CompletionConversationApi(Resource):
)
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -137,9 +137,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@@ -154,14 +153,12 @@ class CompletionConversationDetailApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -206,9 +203,9 @@ class ChatConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -260,6 +257,7 @@ class ChatConversationApi(Resource):
)
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -309,7 +307,7 @@ class ChatConversationApi(Resource):
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]:
case "created_at":
@@ -341,9 +339,8 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@@ -358,14 +355,12 @@ class ChatConversationDetailApi(Resource):
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -374,6 +369,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@@ -1,6 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
@@ -48,11 +47,11 @@ class RuleGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
@@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource):
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
@@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource):
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
)
@@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
@@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource):
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
@@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
@@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource):
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],

View File

@@ -1,16 +1,15 @@
import json
from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
@@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
@api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
@@ -48,14 +47,14 @@ class AppMCPServerController(Resource):
)
@api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_fields)
@edit_permission_required
def post(self, app_model):
if not current_user.is_editor:
raise NotFound()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
@@ -71,7 +70,7 @@ class AppMCPServerController(Resource):
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16),
)
db.session.add(server)
@@ -95,14 +94,13 @@ class AppMCPServerController(Resource):
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_fields)
@edit_permission_required
def put(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=False, location="json")
@@ -142,13 +140,13 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
@edit_permission_required
def get(self, server_id):
if not current_user.is_editor:
raise NotFound()
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
)
if not server:

View File

@@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns
from controllers.console.app.error import (
@@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -26,8 +27,7 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
@@ -56,15 +56,13 @@ class ChatMessageListApi(Resource):
)
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
@@ -154,8 +152,7 @@ class MessageFeedbackApi(Resource):
@login_required
@account_initialization_required
def post(self, app_model):
if current_user is None:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
@@ -211,18 +208,14 @@ class MessageAnnotationApi(Resource):
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
@account_initialization_required
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
@@ -270,6 +263,7 @@ class MessageSuggestedQuestionApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id)
try:
@@ -304,12 +298,12 @@ class MessageApi(Resource):
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(message_detail_fields)
def get(self, app_model, message_id):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@@ -2,7 +2,6 @@ import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
@@ -14,8 +13,8 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
from models.account import Account
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@@ -53,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
)
@@ -94,12 +91,12 @@ class ModelConfigResource(Resource):
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@@ -133,7 +130,7 @@ class ModelConfigResource(Resource):
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
@@ -141,7 +138,7 @@ class ModelConfigResource(Resource):
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@@ -172,6 +169,8 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@@ -9,8 +8,8 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Account, Site
from libs.login import current_account_with_tenant, login_required
from models import Site
def parse_app_site_args():
@@ -76,9 +75,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@@ -107,8 +107,6 @@ class AppSite(Resource):
if value is not None:
setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()
@@ -131,6 +129,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -140,8 +140,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound
site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()

View File

@@ -4,7 +4,6 @@ from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
@@ -37,7 +36,7 @@ class DailyMessageStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -52,7 +51,8 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -109,13 +109,13 @@ class DailyConversationStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
@@ -175,7 +175,7 @@ class DailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -190,8 +190,8 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -247,7 +247,7 @@ class DailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -263,8 +263,8 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -322,7 +322,7 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -345,8 +345,8 @@ FROM
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -413,7 +413,7 @@ class UserSatisfactionRateStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -432,8 +432,8 @@ LEFT JOIN
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -494,7 +494,7 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -509,8 +509,8 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -566,7 +566,7 @@ class TokensPerSecondStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -584,8 +584,8 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@@ -12,7 +12,7 @@ import services
from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -25,10 +25,10 @@ from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from models.account import Account
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@@ -69,15 +69,11 @@ class DraftWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -109,14 +105,12 @@ class DraftWorkflowApi(Resource):
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
@@ -148,10 +142,6 @@ class DraftWorkflowApi(Resource):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
try:
@@ -205,17 +195,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App):
"""
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
@@ -270,16 +255,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@@ -322,16 +303,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@@ -374,17 +351,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@@ -427,17 +399,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@@ -479,17 +446,12 @@ class DraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Run draft workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
@@ -525,17 +487,11 @@ class WorkflowTaskStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str):
"""
Stop workflow task
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
@@ -567,17 +523,12 @@ class DraftWorkflowNodeRunApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
@@ -621,17 +572,11 @@ class PublishedWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get published workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -643,16 +588,12 @@ class PublishedWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Publish workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
@@ -674,8 +615,12 @@ class PublishedWorkflowApi(Resource):
marked_comment=args.marked_comment or "",
)
app_model.workflow_id = workflow.id
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
# Update app_model within the same session to ensure atomicity
app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at)
@@ -697,17 +642,11 @@ class DefaultBlockConfigsApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_configs()
@@ -724,16 +663,11 @@ class DefaultBlockConfigApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
@@ -764,17 +698,14 @@ class ConvertToWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@edit_permission_required
def post(self, app_model: App):
"""
Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
if request.data:
parser = reqparse.RequestParser()
@@ -807,15 +738,12 @@ class PublishedAllWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get published workflows
"""
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
@@ -874,16 +802,12 @@ class WorkflowByIdApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
@@ -929,16 +853,11 @@ class WorkflowByIdApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService()
# Create a session and manage the transaction

View File

@@ -22,8 +22,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
from models import Account, App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService

View File

@@ -1,6 +1,5 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
@@ -14,7 +13,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser
from services.workflow_run_service import WorkflowRunService

View File

@@ -4,7 +4,6 @@ from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse
from controllers.console import api, console_ns
@@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
@@ -29,7 +28,7 @@ class WorkflowDailyRunsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -47,9 +46,9 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -97,7 +96,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -115,9 +114,9 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -165,7 +164,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -183,9 +182,9 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -238,7 +237,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@@ -269,9 +268,9 @@ GROUP BY
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from libs.login import current_account_with_tenant
from models import App, AppMode
from models.account import Account
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account)
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@@ -7,7 +7,7 @@ from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from models import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser()
@@ -103,7 +103,7 @@ class ActivateApi(Resource):
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
@@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required
@account_initialization_required
def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
@@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
@@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204

View File

@@ -2,13 +2,12 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api, console_ns
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
@@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
@api.response(403, "Admin privileges required")
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()

View File

@@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError

View File

@@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService

View File

@@ -1,5 +1,3 @@
from typing import cast
import flask_login
from flask import request
from flask_restx import Resource, reqparse
@@ -26,7 +24,7 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from models.account import Account
from libs.login import current_account_with_tenant
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
@@ -96,7 +94,8 @@ class LoginApi(Resource):
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account)

View File

@@ -14,8 +14,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
@@ -130,11 +129,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@@ -1,16 +1,15 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar
import flask_login
from flask import jsonify, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@@ -116,7 +115,8 @@ class OAuthServerUserAuthorizeApi(Resource):
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)

View File

@@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required
from models.model import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -14,17 +13,13 @@ class Subscription(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@@ -34,7 +29,6 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
return BillingService.get_invoices(current_user.email, current_tenant_id)

View File

@@ -1,9 +1,8 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from .. import console_ns
@@ -17,17 +16,17 @@ class ComplianceApi(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")
return BillingService.get_compliance_download_link(
doc_name=args.doc_name,
account_id=current_user.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
ip=ip_address,
device_info=device_info,
)

View File

@@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
@@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
).all()
@@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
@@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
@@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
)
text_docs = extractor.extract()
@@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
@@ -256,14 +263,14 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],

View File

@@ -1,7 +1,6 @@
from typing import Any, cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@@ -30,10 +29,9 @@ from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -45,6 +43,79 @@ def _validate_name(name: str) -> str:
return name
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
"""
Get supported retrieval methods based on vector database type.
Args:
vector_type: Vector database type, can be None
is_mock: Whether this is a Mock API, affects MILVUS handling
Returns:
Dictionary containing supported retrieval methods
Raises:
ValueError: If vector_type is None or unsupported
"""
if vector_type is None:
raise ValueError("Vector store type is not configured.")
# Define vector database types that only support semantic search
semantic_only_types = {
VectorType.RELYT,
VectorType.TIDB_VECTOR,
VectorType.CHROMA,
VectorType.PGVECTO_RS,
VectorType.VIKINGDB,
VectorType.UPSTASH,
}
# Define vector database types that support all retrieval methods
full_search_types = {
VectorType.QDRANT,
VectorType.WEAVIATE,
VectorType.OPENSEARCH,
VectorType.ANALYTICDB,
VectorType.MYSCALE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.ELASTICSEARCH_JA,
VectorType.PGVECTOR,
VectorType.VASTBASE,
VectorType.TIDB_ON_QDRANT,
VectorType.LINDORM,
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
VectorType.MATRIXONE,
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
full_methods = {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
if vector_type == VectorType.MILVUS:
return semantic_methods if is_mock else full_methods
if vector_type in semantic_only_types:
return semantic_methods
elif vector_type in full_search_types:
return full_methods
else:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets")
class DatasetListApi(Resource):
@api.doc("get_datasets")
@@ -65,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
@@ -73,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -178,6 +250,7 @@ class DatasetListApi(Resource):
required=False,
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@@ -185,11 +258,11 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=cast(Account, current_user),
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -213,6 +286,7 @@ class DatasetApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -232,7 +306,7 @@ class DatasetApi(Resource):
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -345,6 +419,7 @@ class DatasetApi(Resource):
)
args = parser.parse_args()
data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
@@ -367,7 +442,7 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
@@ -391,9 +466,10 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
try:
@@ -432,6 +508,7 @@ class DatasetQueryApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -483,15 +560,14 @@ class DatasetIndexingEstimateApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
).all()
if file_details is None:
@@ -500,7 +576,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=args["doc_form"],
)
@@ -512,14 +588,14 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@@ -529,13 +605,13 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
@@ -548,7 +624,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
@@ -579,6 +655,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required
@marshal_with(related_app_list)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -610,11 +687,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
).all()
documents_status = []
for document in documents:
@@ -666,10 +742,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
@@ -679,12 +754,13 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
)
@@ -697,7 +773,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@@ -717,6 +793,7 @@ class DatasetApiDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
@@ -726,7 +803,7 @@ class DatasetApiDeleteApi(Resource):
key = (
db.session.query(ApiToken)
.where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
@@ -777,49 +854,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@@ -832,48 +867,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
@@ -908,6 +902,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@@ -6,7 +6,6 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -53,9 +52,8 @@ from fields.document_fields import (
document_with_segments_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@@ -79,12 +78,13 @@ class DocumentResource(Resource):
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.")
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
@login_required
@account_initialization_required
def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args
document_id = req_data.get("document_id")
@@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
@@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if search:
search = f"%{search}%"
@@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -372,6 +375,7 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
@@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"],
@@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
knowledge_config=knowledge_config,
account=cast(Account, current_user),
account=current_user,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@@ -475,14 +480,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
try:
estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
[extract_setting],
data_process_rule_dict,
document.doc_form,
@@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@login_required
@account_initialization_required
def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
@@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
@@ -538,7 +544,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
@@ -546,14 +552,14 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
@@ -563,13 +569,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
@@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
data_process_rule_dict,
document.doc_form,
@@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
@login_required
@account_initialization_required
def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
@@ -1077,12 +1086,13 @@ class DocumentRenameApi(DocumentResource):
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
@@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
@account_initialization_required
def get(self, dataset_id, document_id):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.")
if document.data_source_type != "website_crawl":
raise ValueError("Document is not a website document.")

View File

@@ -1,7 +1,6 @@
import uuid
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
@@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment)
.where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
DocumentSegment.tenant_id == current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
@@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_tenant_id,
current_user.id,
)
except Exception as e:
@@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
@@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)

View File

@@ -1,7 +1,4 @@
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -10,8 +7,7 @@ from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
@@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search
page, limit, current_tenant_id, search
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
@@ -60,6 +57,7 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
@@ -85,7 +83,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
tenant_id=current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -115,6 +113,7 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser()
@@ -136,7 +135,7 @@ class ExternalApiTemplateApi(Resource):
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=args,
@@ -148,13 +147,14 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204
@@ -199,7 +199,8 @@ class ExternalDatasetCreateApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -223,7 +224,7 @@ class ExternalDatasetCreateApi(Resource):
try:
dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
user_id=current_user.id,
args=args,
)
@@ -255,6 +256,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@login_required
@account_initialization_required
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -277,7 +279,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)

View File

@@ -1,7 +1,5 @@
import logging
from typing import cast
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -21,6 +19,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,

View File

@@ -1,13 +1,12 @@
from typing import Literal
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@@ -24,6 +23,7 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
@@ -59,6 +59,7 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
@@ -79,6 +80,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -108,6 +110,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -128,6 +131,7 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@@ -1,19 +1,15 @@
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, provider_id: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
@@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
user_id=current_user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
@@ -130,9 +126,9 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
@@ -145,7 +141,7 @@ class DatasourceAuth(Resource):
try:
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
@@ -160,8 +156,10 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
@@ -173,18 +171,20 @@ class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
@@ -197,18 +197,20 @@ class DatasourceAuthUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
@@ -224,10 +226,10 @@ class DatasourceAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@@ -237,10 +239,10 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@@ -249,9 +251,10 @@ class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
@@ -259,7 +262,7 @@ class DatasourceAuthOauthCustomClient(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
@@ -270,10 +273,12 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
@@ -284,16 +289,17 @@ class DatasourceAuthDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
)
@@ -305,9 +311,10 @@ class DatasourceUpdateProviderNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@@ -315,7 +322,7 @@ class DatasourceUpdateProviderNameApi(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],

View File

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -13,7 +12,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
@@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="",
description="",

View File

@@ -23,7 +23,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models.account import Account
from models import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService

View File

@@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -13,8 +10,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@@ -28,7 +24,8 @@ class RagPipelineImportApi(Resource):
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -47,7 +44,7 @@ class RagPipelineImportApi(Resource):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
@@ -60,9 +57,9 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@@ -74,20 +71,21 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@@ -100,7 +98,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session:
@@ -117,7 +116,8 @@ class RagPipelineExportApi(Resource):
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Add include_secret params

View File

@@ -18,6 +18,7 @@ from controllers.console.app.error import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
)
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError
@@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@@ -79,13 +77,13 @@ class DraftRagPipelineApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
@@ -154,13 +152,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
@@ -194,7 +192,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -229,7 +228,8 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -264,7 +264,8 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -303,7 +304,7 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor:
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
@@ -344,7 +345,7 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor:
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
@@ -385,7 +386,8 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -428,7 +430,8 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -472,7 +475,8 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -505,7 +509,8 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@@ -525,7 +530,8 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published:
return None
@@ -545,7 +551,8 @@ class PublishedRagPipelineApi(Resource):
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService()
@@ -580,7 +587,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
@@ -599,7 +607,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -631,7 +640,8 @@ class PublishedAllRagPipelineApi(Resource):
"""
Get published workflows
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -681,7 +691,8 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes
"""
# Check permission
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -733,13 +744,11 @@ class PublishedRagPipelineSecondStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@@ -759,13 +768,11 @@ class PublishedRagPipelineFirstStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@@ -785,13 +792,11 @@ class DraftRagPipelineFirstStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@@ -811,13 +816,11 @@ class DraftRagPipelineSecondStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@@ -880,7 +883,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id):
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
"""
@@ -903,14 +906,8 @@ class DatasourceListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
if not isinstance(user, Account):
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
@@ -940,11 +937,11 @@ class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
if not isinstance(current_user, Account):
raise Forbidden()
@edit_permission_required
def post(self, dataset_id: str):
current_user, _ = current_account_with_tenant()
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not current_user.is_dataset_operator:
raise Forbidden()
dataset_id = str(dataset_id)
@@ -959,14 +956,13 @@ class RagPipelineDatasourceVariableApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields)
def post(self, pipeline: Pipeline):
"""
Set datasource variables
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=dict, required=True, location="json")

View File

@@ -3,8 +3,7 @@ from functools import wraps
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
@@ -17,8 +16,7 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account):
raise ValueError("current_user is not an account")
_, current_tenant_id = current_account_with_tenant()
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
@@ -27,7 +25,7 @@ def get_rag_pipeline(
pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
)

View File

@@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required
from models import Account, App, InstalledApp, RecommendedApp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
current_user, current_tenant_id = current_account_with_tenant()
if app_id:
installed_apps = db.session.scalars(
@@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("App not found")
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None:
@@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
"""
def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app)

View File

@@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -48,6 +47,7 @@ logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@@ -61,8 +61,6 @@ class MessageListApi(InstalledAppResource):
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
@@ -78,6 +76,7 @@ class MessageListApi(InstalledAppResource):
)
class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
@@ -88,8 +87,6 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
@@ -109,6 +106,7 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@@ -124,8 +122,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=current_user,
@@ -159,6 +155,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -167,8 +164,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)

View File

@@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.login import current_user
from models import Account
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@@ -35,6 +34,7 @@ class SavedMessageListApi(InstalledAppResource):
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@@ -44,11 +44,10 @@ class SavedMessageListApi(InstalledAppResource):
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@@ -58,8 +57,6 @@ class SavedMessageListApi(InstalledAppResource):
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -72,6 +69,7 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
@@ -79,8 +77,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204

View File

@@ -2,14 +2,13 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import InstalledApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -24,11 +23,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
installed_app = (
db.session.query(InstalledApp)
.where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.first()
)
@@ -54,6 +52,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id

View File

@@ -1,11 +1,10 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@@ -47,7 +46,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension")
@@ -68,14 +67,16 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
@@ -96,7 +97,7 @@ class APIBasedExtensionDetailAPI(Resource):
@marshal_with(api_based_extension_fields)
def get(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@@ -120,9 +121,9 @@ class APIBasedExtensionDetailAPI(Resource):
@marshal_with(api_based_extension_fields)
def post(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
@@ -147,9 +148,9 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
def delete(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db)

View File

@@ -1,7 +1,6 @@
from flask_login import current_user
from flask_restx import Resource, fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService
from . import api, console_ns
@@ -23,7 +22,9 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@@ -1,7 +1,6 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden
@@ -22,8 +21,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
@@ -53,6 +51,7 @@ class FileApi(Resource):
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@@ -65,16 +64,12 @@ class FileApi(Resource):
if not file.filename:
raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden()
if source not in ("datasets", None):
source = None
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
@@ -108,4 +103,4 @@ class FileSupportTypeApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@@ -1,8 +1,6 @@
import urllib.parse
from typing import cast
import httpx
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@@ -16,7 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
@@ -65,7 +63,7 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user = cast(Account, current_user)
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@@ -1,12 +1,11 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
@@ -24,9 +23,10 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
return tags, 200
@@ -34,8 +34,9 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -59,9 +60,10 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def patch(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -81,9 +83,10 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id)
@@ -97,8 +100,9 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -123,8 +127,9 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()

View File

@@ -2,11 +2,11 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
@@ -20,8 +20,9 @@ def plugin_permission_required(
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
with Session(db.engine) as session:
permission = (

View File

@@ -2,7 +2,6 @@ from datetime import datetime
import pytz
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -37,9 +36,8 @@ from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -50,9 +48,7 @@ class AccountInitApi(Resource):
@setup_required
@login_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
if account.status == "active":
raise AccountAlreadyInitedError()
@@ -106,8 +102,7 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields)
@enterprise_license_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
return current_user
@@ -118,8 +113,7 @@ class AccountNameApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
@@ -140,8 +134,7 @@ class AccountAvatarApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
@@ -158,8 +151,7 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
args = parser.parse_args()
@@ -176,8 +168,7 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
args = parser.parse_args()
@@ -194,8 +185,7 @@ class AccountTimezoneApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
@@ -216,8 +206,7 @@ class AccountPasswordApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json")
@@ -253,9 +242,7 @@ class AccountIntegrateApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
@@ -298,9 +285,7 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code)
@@ -314,9 +299,7 @@ class AccountDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
@@ -358,9 +341,7 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
return BillingService.EducationIdentity.verify(account.id, account.email)
@@ -380,9 +361,7 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
@@ -399,9 +378,7 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat
@@ -441,6 +418,7 @@ class ChangeEmailSendEmailApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
@@ -467,8 +445,6 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email:
raise InvalidEmailError()
else:
@@ -551,8 +527,7 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()

View File

@@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, fields
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.agent_service import AgentService
@@ -21,10 +20,11 @@ class AgentProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@@ -43,7 +43,5 @@ class AgentProviderApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
current_user, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
@@ -6,7 +5,7 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
@@ -34,7 +33,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@@ -51,7 +50,7 @@ class EndpointCreateApi(Resource):
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
@@ -80,7 +79,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@@ -93,7 +92,7 @@ class EndpointListApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
@@ -123,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@@ -138,7 +137,7 @@ class EndpointListForSinglePluginApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
@@ -165,7 +164,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -177,9 +176,7 @@ class EndpointDeleteApi(Resource):
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@@ -207,7 +204,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -224,7 +221,7 @@ class EndpointUpdateApi(Resource):
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
@@ -250,7 +247,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -262,9 +259,7 @@ class EndpointEnableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@@ -285,7 +280,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -297,7 +292,5 @@ class EndpointDisableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

View File

@@ -5,8 +5,8 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@@ -18,12 +18,11 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -72,12 +71,11 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")

View File

@@ -1,7 +1,6 @@
from urllib import parse
from flask import abort, request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@@ -26,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@@ -42,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@@ -70,9 +68,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@@ -121,8 +117,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
@@ -161,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@@ -190,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@@ -213,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@@ -251,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@@ -297,8 +286,7 @@ class OwnerTransfer(Resource):
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@@ -1,7 +1,6 @@
import io
from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
@@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
@@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
@@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
)
return {"result": "success"}, 204
@@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
)
@@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
model_provider_service = ModelProviderService()
@@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email,
)

View File

@@ -1,6 +1,5 @@
import logging
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"model_type",
@@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
@@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
@@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
try:
@@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
try:
model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required
@account_initialization_required
def get(self, model_type):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@@ -1,7 +1,6 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
@@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {
@@ -44,7 +43,7 @@ class PluginListApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=1)
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
@@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
@@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
@@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
@@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
@@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
@@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id
parser = reqparse.RequestParser()
@@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource):
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
@@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")

View File

@@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import (
Resource,
reqparse,
@@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -53,10 +52,9 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
@@ -78,9 +76,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
@@ -96,9 +92,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@@ -109,11 +103,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
@@ -131,10 +124,9 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -161,13 +153,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@@ -193,7 +184,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
@@ -218,13 +209,12 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -258,10 +248,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -282,10 +271,9 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -308,13 +296,12 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -350,13 +337,12 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -377,10 +363,9 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -401,8 +386,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
@@ -444,9 +428,9 @@ class ToolApiProviderPreviousTestApi(Resource):
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
current_tenant_id,
args["provider_name"] or "",
args["tool_name"],
args["credentials"],
@@ -462,13 +446,12 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -502,13 +485,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -545,13 +527,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -571,10 +552,9 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@@ -606,10 +586,9 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@@ -631,10 +610,9 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -653,8 +631,7 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
[
@@ -672,10 +649,9 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -709,19 +685,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
@@ -800,11 +775,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
@@ -819,13 +795,13 @@ class ToolOAuthCustomClient(Resource):
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@@ -835,20 +811,18 @@ class ToolOAuthCustomClient(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@@ -858,9 +832,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
tenant_id=current_tenant_id, provider_name=provider
)
)
@@ -871,7 +846,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@@ -900,12 +875,12 @@ class ToolProviderMCPApi(Resource):
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
@@ -940,8 +915,9 @@ class ToolProviderMCPApi(Resource):
pass
else:
raise ValueError("Server URL is not valid.")
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
@@ -962,7 +938,8 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@@ -977,7 +954,7 @@ class ToolMCPAuthApi(Resource):
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
@@ -1018,8 +995,8 @@ class ToolMCPDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@@ -1029,8 +1006,7 @@ class ToolMCPListAllApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
@@ -1043,7 +1019,7 @@ class ToolMCPUpdateApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,

View File

@@ -1,7 +1,6 @@
import logging
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@@ -24,8 +23,8 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from models.account import Account, Tenant, TenantStatus
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
from services.file_service import FileService
@@ -71,8 +70,7 @@ class TenantListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
@@ -86,7 +84,7 @@ class TenantListApi(Resource):
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
tenant_dicts.append(tenant_dict)
@@ -131,8 +129,7 @@ class TenantApi(Resource):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
@@ -156,8 +153,7 @@ class SwitchWorkspaceApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
@@ -182,16 +178,12 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
@@ -213,8 +205,7 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@@ -254,15 +245,14 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required
# Change workspace name
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
if not current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"]
db.session.commit()

View File

@@ -7,12 +7,12 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
@@ -29,9 +29,8 @@ def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user
if account.status == AccountStatus.UNINITIALIZED:
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
return view(*args, **kwargs)
@@ -75,7 +74,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@@ -87,7 +87,8 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@@ -128,7 +129,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@@ -151,10 +153,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
key = f"rate_limit_{current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@@ -165,7 +168,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@@ -185,14 +188,15 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@@ -242,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated
def email_register_enabled(view):
def email_register_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.is_allow_register:
return view(*args, **kwargs)
@@ -271,7 +275,8 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@@ -281,12 +286,26 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated
def knowledge_pipeline_publish_enabled(view):
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)
return decorated
def edit_permission_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@@ -17,7 +17,7 @@ class BaseMail(Resource):
def post(self):
args = _mail_parser.parse_args()
send_inner_email_task.delay(
send_inner_email_task.delay( # type: ignore
to=args["to"],
subject=args["subject"],
body=args["body"],

View File

@@ -31,7 +31,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import length_prefixed_response
from models.account import Account, Tenant
from models import Account, Tenant
from models.model import EndUser

View File

@@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with Session(db.engine) as session:
user_model = None
@@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
raise ValueError("tenant_id is required")
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (

View File

@@ -7,7 +7,7 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from models.account import Account
from models import Account
from services.account_service import TenantService

View File

@@ -10,7 +10,7 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models.account import Account
from models import Account
from models.model import App
from services.annotation_service import AppAnnotationService

View File

@@ -17,7 +17,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError

View File

@@ -1,5 +1,4 @@
from flask import request
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
@@ -16,6 +15,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import current_account_with_tenant
from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
@@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
@@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
page=page,
@@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
@@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")
@@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@@ -17,7 +17,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
from services.feature_service import FeatureService
@@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
@@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
)
session.add(end_user)

View File

@@ -20,7 +20,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService

View File

@@ -197,12 +197,12 @@ class DatasetConfigManager:
# strategy
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
has_datasets = False
if config.get("agent_mode", {}).get("strategy") in {
PlanningStrategy.ROUTER.value,
PlanningStrategy.REACT_ROUTER.value,
PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0]

View File

@@ -68,9 +68,13 @@ class ModelConfigConverter:
# get model mode
model_mode = model_config.mode
if not model_mode:
model_mode = LLMMode.CHAT.value
model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
try:
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
except ValueError:
# Fall back to CHAT mode if the stored value is invalid
model_mode = LLMMode.CHAT
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
@@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config["model"]["mode"] == ModelMode.CHAT.value:
if config["model"]["mode"] == ModelMode.CHAT:
prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10:

View File

@@ -70,8 +70,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow

View File

@@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode must be of boolean type")
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
agent_mode["strategy"] = PlanningStrategy.ROUTER
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

View File

@@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")

View File

@@ -1,10 +1,12 @@
import logging
import queue
import threading
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any
from cachetools import TTLCache, cachedmethod
from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta
@@ -45,6 +47,8 @@ class AppQueueManager:
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
def listen(self):
"""
@@ -157,6 +161,7 @@ class AppQueueManager:
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
def _is_stopped(self) -> bool:
"""
Check if task is stopped

View File

@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService

View File

@@ -112,7 +112,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@@ -207,6 +207,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)

View File

@@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@@ -61,7 +61,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (

View File

@@ -100,8 +100,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@@ -244,8 +244,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
parameter["type"] = "files"
# -------------

View File

@@ -1,5 +1,5 @@
import enum
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter):
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
class DatasourceInvokeFrom(StrEnum):
"""
Enum class for datasource invoke
"""

View File

@@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_type == ProviderType.CUSTOM,
Provider.provider_name.in_(self._get_provider_names()),
)
@@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel):
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
provider_type=ProviderType.CUSTOM,
is_valid=True,
credential_id=new_record.id,
)
@@ -472,6 +472,9 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
session.commit()
except Exception:
@@ -1145,6 +1148,15 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Can't add same credential")
provider_model_record.credential_id = credential_record.id
provider_model_record.updated_at = naive_utc_now()
# clear cache
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
session.add(provider_model_record)
session.commit()
@@ -1178,6 +1190,14 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
# clear cache
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
def delete_custom_model(self, model_type: ModelType, model: str):
"""
Delete custom model.
@@ -1414,7 +1434,7 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@@ -1,13 +1,13 @@
from typing import cast
import requests
import httpx
from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60)
timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str):
@@ -27,25 +27,23 @@ class APIBasedExtensionRequestor:
url = self.api_endpoint
try:
# proxy support for security
proxies = None
mounts: dict[str, httpx.BaseTransport] | None = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
"http": dify_config.SSRF_PROXY_HTTP_URL,
"https": dify_config.SSRF_PROXY_HTTPS_URL,
mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
response = requests.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
timeout=self.timeout,
proxies=proxies,
)
except requests.Timeout:
with httpx.Client(mounts=mounts, timeout=self.timeout) as client:
response = client.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
)
except httpx.TimeoutException:
raise ValueError("request timeout")
except requests.ConnectionError:
except httpx.RequestError:
raise ValueError("request connection error")
if response.status_code != 200:

View File

@@ -343,7 +343,7 @@ class IndexingRunner:
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
@@ -356,7 +356,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
@@ -379,7 +379,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],

View File

@@ -224,8 +224,8 @@ def _handle_native_json_schema(
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA
return model_parameters
@@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list):
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
if ResponseFormat.JSON in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON
elif ResponseFormat.JSON_OBJECT in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT
def _handle_prompt_based_schema(

Some files were not shown because too many files have changed in this diff Show More