refactor: select in console explore and workspace controllers (#33842)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-03-21 12:06:17 +01:00 committed by GitHub
parent 2ce2fbc2d4
commit 35cbd83e83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 79 additions and 89 deletions

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
@ -17,14 +18,18 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
banners = db.session.scalars(
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
banners = db.session.scalars(
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
).all()
# Convert banners to serializable format
result = []
for banner in banners:

View File

@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
def post(self):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
recommended_app = db.session.scalar(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
)
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == payload.app_id).first()
app = db.session.get(App, payload.app_id)
if app is None:
raise NotFound("App entity not found")
@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = (
db.session.query(InstalledApp)
installed_app = db.session.scalar(
select(InstalledApp)
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first()
.limit(1)
)
if installed_app is None:

View File

@ -4,6 +4,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -476,7 +477,7 @@ class TrialSitApi(Resource):
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise Forbidden()
@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
workflow = db.session.get(Workflow, app_model.workflow_id)
return workflow

View File

@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
installed_app = (
db.session.query(InstalledApp)
installed_app = db.session.scalar(
select(InstalledApp)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if installed_app is None:
@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
if trial_app is None:
raise TrialAppNotAllowed()
@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
account_trial_app_record = db.session.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
.limit(1)
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:

View File

@ -212,13 +212,13 @@ class AccountInitApi(Resource):
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
invitation_code = db.session.scalar(
select(InvitationCode)
.where(
InvitationCode.code == args.invitation_code,
InvitationCode.status == InvitationCodeStatus.UNUSED,
)
.first()
.limit(1)
)
if not invitation_code:

View File

@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
member = db.session.get(Account, str(member_id))
if member is None:
abort(404)
else:

View File

@ -220,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")

View File

@ -24,13 +24,8 @@ class TestBannerApi:
banner.status = BannerStatus.ENABLED
banner.created_at = datetime(2024, 1, 1)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = [banner]
session = MagicMock()
session.query.return_value = query
session.scalars.return_value.all.return_value = [banner]
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
result = method(api)
@ -58,16 +53,14 @@ class TestBannerApi:
banner.status = BannerStatus.ENABLED
banner.created_at = None
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.side_effect = [
scalars_result = MagicMock()
scalars_result.all.side_effect = [
[],
[banner],
]
session = MagicMock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
result = method(api)
@ -87,13 +80,8 @@ class TestBannerApi:
api = banner_module.BannerApi()
method = unwrap(api.get)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = []
session = MagicMock()
session.query.return_value = query
session.scalars.return_value.all.return_value = []
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
result = method(api)

View File

@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi:
app_entity.tenant_id = "t2"
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
None,
]
# scalar() is called for recommended_app and installed_app lookups
session.scalar.side_effect = [recommended, None]
# get() is called for app PK lookup
session.get.return_value = app_entity
with (
app.test_request_context("/", json={"app_id": "a1"}),
@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi:
method = unwrap(api.post)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with (
app.test_request_context("/", json={"app_id": "a1"}),
@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi:
app_entity = MagicMock(is_public=False)
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
]
# scalar() returns recommended_app
session.scalar.return_value = recommended
# get() returns the app entity
session.get.return_value = app_entity
with (
app.test_request_context("/", json={"app_id": "a1"}),

View File

@ -958,8 +958,8 @@ class TestTrialSitApi:
app_model = MagicMock()
app_model.id = "a1"
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
mock_scalar.return_value = None
with pytest.raises(Forbidden):
method(api, app_model)
@ -973,8 +973,8 @@ class TestTrialSitApi:
app_model.tenant = MagicMock()
app_model.tenant.status = TenantStatus.ARCHIVE
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = site
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
mock_scalar.return_value = site
with pytest.raises(Forbidden):
method(api, app_model)
@ -990,10 +990,10 @@ class TestTrialSitApi:
with (
app.test_request_context("/"),
patch.object(module.db.session, "query") as mock_query,
patch.object(module.db.session, "scalar") as mock_scalar,
patch.object(module.SiteResponse, "model_validate") as mock_validate,
):
mock_query.return_value.where.return_value.first.return_value = site
mock_scalar.return_value = site
mock_validate_result = MagicMock()
mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"}
mock_validate.return_value = mock_validate_result

View File

@ -34,9 +34,9 @@ def test_installed_app_required_not_found():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = None
scalar_mock.return_value = None
with pytest.raises(NotFound):
view("app-id")
@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
patch("controllers.console.explore.wraps.db.session.delete"),
patch("controllers.console.explore.wraps.db.session.commit"),
):
q.return_value.where.return_value.first.return_value = installed_app
scalar_mock.return_value = installed_app
with pytest.raises(NotFound):
view("app-id")
@ -76,9 +76,9 @@ def test_installed_app_required_success():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = installed_app
scalar_mock.return_value = installed_app
result = view("app-id")
assert result == installed_app
@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = None
scalar_mock.return_value = None
with pytest.raises(TrialAppNotAllowed):
view("app-id")
@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.side_effect = [
scalar_mock.side_effect = [
trial_app,
record,
]
@ -194,9 +194,9 @@ def test_trial_app_required_success():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.side_effect = [
scalar_mock.side_effect = [
trial_app,
record,
]

View File

@ -55,9 +55,9 @@ class TestAccountInitApi:
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
scalar_mock.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"

View File

@ -207,10 +207,10 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 200
@ -226,9 +226,9 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
):
q.return_value.where.return_value.first.return_value = None
get_mock.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
@ -244,13 +244,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 400
@ -266,13 +266,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 403
@ -288,13 +288,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 404

View File

@ -449,12 +449,12 @@ class TestSwitchWorkspaceApi:
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
),
):
query_mock.return_value.get.return_value = tenant
get_mock.return_value = tenant
result = method(api)
assert result["result"] == "success"
@ -488,9 +488,9 @@ class TestSwitchWorkspaceApi:
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
):
query_mock.return_value.get.return_value = None
get_mock.return_value = None
with pytest.raises(ValueError):
method(api)