mirror of https://github.com/langgenius/dify.git
refactor: select in console explore and workspace controllers (#33842)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
2ce2fbc2d4
commit
35cbd83e83
|
|
@ -1,5 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
|
|
@ -17,14 +18,18 @@ class BannerApi(Resource):
|
|||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
banners = db.session.scalars(
|
||||
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
|
||||
).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
banners = db.session.scalars(
|
||||
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
|
||||
).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
|
|
|
|||
|
|
@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
|
|||
def post(self):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||
recommended_app = db.session.scalar(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
|
||||
)
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
raise NotFound("App entity not found")
|
||||
|
|
@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
|
|||
if not app.is_public:
|
||||
raise Forbidden("You can't install a non-public app")
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
installed_app = db.session.scalar(
|
||||
select(InstalledApp)
|
||||
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if installed_app is None:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Literal, cast
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -476,7 +477,7 @@ class TrialSitApi(Resource):
|
|||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
|
@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
|
|||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
workflow = db.session.get(Workflow, app_model.workflow_id)
|
||||
return workflow
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar
|
|||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
|
|
@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
installed_app = db.session.scalar(
|
||||
select(InstalledApp)
|
||||
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if installed_app is None:
|
||||
|
|
@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
|||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
|
@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
|||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
account_trial_app_record = db.session.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
|
|
|
|||
|
|
@ -212,13 +212,13 @@ class AccountInitApi(Resource):
|
|||
raise ValueError("invitation_code is required")
|
||||
|
||||
# check invitation code
|
||||
invitation_code = (
|
||||
db.session.query(InvitationCode)
|
||||
invitation_code = db.session.scalar(
|
||||
select(InvitationCode)
|
||||
.where(
|
||||
InvitationCode.code == args.invitation_code,
|
||||
InvitationCode.status == InvitationCodeStatus.UNUSED,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not invitation_code:
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
|
|||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if member is None:
|
||||
abort(404)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
|
|||
except Exception:
|
||||
raise AccountNotLinkTenantError("Account not link tenant")
|
||||
|
||||
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
|
||||
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
|
||||
if new_tenant is None:
|
||||
raise ValueError("Tenant not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,13 +24,8 @@ class TestBannerApi:
|
|||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = datetime(2024, 1, 1)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = [banner]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [banner]
|
||||
|
||||
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
|
@ -58,16 +53,14 @@ class TestBannerApi:
|
|||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.side_effect = [
|
||||
scalars_result = MagicMock()
|
||||
scalars_result.all.side_effect = [
|
||||
[],
|
||||
[banner],
|
||||
]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
|
@ -87,13 +80,8 @@ class TestBannerApi:
|
|||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = []
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
|
|
|||
|
|
@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi:
|
|||
app_entity.tenant_id = "t2"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
None,
|
||||
]
|
||||
# scalar() is called for recommended_app and installed_app lookups
|
||||
session.scalar.side_effect = [recommended, None]
|
||||
# get() is called for app PK lookup
|
||||
session.get.return_value = app_entity
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
|
|
@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi:
|
|||
method = unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
|
|
@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi:
|
|||
app_entity = MagicMock(is_public=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
]
|
||||
# scalar() returns recommended_app
|
||||
session.scalar.return_value = recommended
|
||||
# get() returns the app entity
|
||||
session.get.return_value = app_entity
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
|
|
|
|||
|
|
@ -958,8 +958,8 @@ class TestTrialSitApi:
|
|||
app_model = MagicMock()
|
||||
app_model.id = "a1"
|
||||
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, app_model)
|
||||
|
||||
|
|
@ -973,8 +973,8 @@ class TestTrialSitApi:
|
|||
app_model.tenant = MagicMock()
|
||||
app_model.tenant.status = TenantStatus.ARCHIVE
|
||||
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = site
|
||||
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
|
||||
mock_scalar.return_value = site
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, app_model)
|
||||
|
||||
|
|
@ -990,10 +990,10 @@ class TestTrialSitApi:
|
|||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module.db.session, "query") as mock_query,
|
||||
patch.object(module.db.session, "scalar") as mock_scalar,
|
||||
patch.object(module.SiteResponse, "model_validate") as mock_validate,
|
||||
):
|
||||
mock_query.return_value.where.return_value.first.return_value = site
|
||||
mock_scalar.return_value = site
|
||||
mock_validate_result = MagicMock()
|
||||
mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"}
|
||||
mock_validate.return_value = mock_validate_result
|
||||
|
|
|
|||
|
|
@ -34,9 +34,9 @@ def test_installed_app_required_not_found():
|
|||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
scalar_mock.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
|
|
@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted():
|
|||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
patch("controllers.console.explore.wraps.db.session.delete"),
|
||||
patch("controllers.console.explore.wraps.db.session.commit"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
scalar_mock.return_value = installed_app
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
|
|
@ -76,9 +76,9 @@ def test_installed_app_required_success():
|
|||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
scalar_mock.return_value = installed_app
|
||||
|
||||
result = view("app-id")
|
||||
assert result == installed_app
|
||||
|
|
@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed():
|
|||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
scalar_mock.return_value = None
|
||||
|
||||
with pytest.raises(TrialAppNotAllowed):
|
||||
view("app-id")
|
||||
|
|
@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded():
|
|||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
scalar_mock.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
|
@ -194,9 +194,9 @@ def test_trial_app_required_success():
|
|||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
scalar_mock.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -55,9 +55,9 @@ class TestAccountInitApi:
|
|||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
|
||||
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.workspace.account.db.session.query") as query_mock,
|
||||
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
|
||||
scalar_mock.return_value = MagicMock(status="unused")
|
||||
resp = method(api)
|
||||
|
||||
assert resp["result"] == "success"
|
||||
|
|
|
|||
|
|
@ -207,10 +207,10 @@ class TestMemberCancelInviteApi:
|
|||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
|
|
@ -226,9 +226,9 @@ class TestMemberCancelInviteApi:
|
|||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
get_mock.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "x")
|
||||
|
|
@ -244,13 +244,13 @@ class TestMemberCancelInviteApi:
|
|||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
|
|
@ -266,13 +266,13 @@ class TestMemberCancelInviteApi:
|
|||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
|
|
@ -288,13 +288,13 @@ class TestMemberCancelInviteApi:
|
|||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
|
|
|
|||
|
|
@ -449,12 +449,12 @@ class TestSwitchWorkspaceApi:
|
|||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
|
||||
),
|
||||
):
|
||||
query_mock.return_value.get.return_value = tenant
|
||||
get_mock.return_value = tenant
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
|
@ -488,9 +488,9 @@ class TestSwitchWorkspaceApi:
|
|||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
|
||||
):
|
||||
query_mock.return_value.get.return_value = None
|
||||
get_mock.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
|
|
|||
Loading…
Reference in New Issue