From 35cbd83e83f6437df621c08accecd6968ba93e01 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Sat, 21 Mar 2026 12:06:17 +0100 Subject: [PATCH] refactor: select in console explore and workspace controllers (#33842) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/explore/banner.py | 11 ++++++--- .../console/explore/installed_app.py | 12 ++++++---- api/controllers/console/explore/trial.py | 11 +++------ api/controllers/console/explore/wraps.py | 15 ++++++------ api/controllers/console/workspace/account.py | 6 ++--- api/controllers/console/workspace/members.py | 2 +- .../console/workspace/workspace.py | 2 +- .../console/explore/test_banner.py | 22 ++++------------- .../console/explore/test_installed_app.py | 19 +++++++-------- .../controllers/console/explore/test_trial.py | 12 +++++----- .../controllers/console/explore/test_wraps.py | 24 +++++++++---------- .../console/workspace/test_accounts.py | 4 ++-- .../console/workspace/test_members.py | 20 ++++++++-------- .../console/workspace/test_workspace.py | 8 +++---- 14 files changed, 79 insertions(+), 89 deletions(-) diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py index 5dfef6bf6a..757061d8dd 100644 --- a/api/controllers/console/explore/banner.py +++ b/api/controllers/console/explore/banner.py @@ -1,5 +1,6 @@ from flask import request from flask_restx import Resource +from sqlalchemy import select from controllers.console import api from controllers.console.explore.wraps import explore_banner_enabled @@ -17,14 +18,18 @@ class BannerApi(Resource): language = request.args.get("language", "en-US") # Build base query for enabled banners - base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) + base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) # Try to get banners in the requested language - banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort) + ).all() # Fallback to en-US if no banners found and language is not en-US if not banners and language != "en-US": - banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort) + ).all() # Convert banners to serializable format result = [] for banner in banners: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index aca766567f..0740dd0e24 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource): def post(self): payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1) + ) if recommended_app is None: raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == payload.app_id).first() + app = db.session.get(App, payload.app_id) if app is None: raise NotFound("App entity not found") @@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource): if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) - .first() + .limit(1) ) if installed_app is None: diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 25bb8ed7fe..a8d8036f0f 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -4,6 +4,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -476,7 +477,7 @@ class TrialSitApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() @@ -541,13 +542,7 @@ class AppWorkflowApi(Resource): if not app_model.workflow_id: raise AppUnavailableError() - workflow = ( - db.session.query(Workflow) - .where( - Workflow.id == app_model.workflow_id, - ) - .first() - ) + workflow = db.session.get(Workflow, app_model.workflow_id) return workflow diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 03edb871e6..9d9337e63e 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import abort from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed @@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): _, current_tenant_id = current_account_with_tenant() - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id) - .first() + .limit(1) ) if installed_app is None: @@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): current_user, _ = current_account_with_tenant() - trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first() + trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1)) if trial_app is None: raise TrialAppNotAllowed() @@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): if app is None: raise TrialAppNotAllowed() - account_trial_app_record = ( - db.session.query(AccountTrialAppRecord) + account_trial_app_record = db.session.scalar( + select(AccountTrialAppRecord) .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id) - .first() + .limit(1) ) if account_trial_app_record: if account_trial_app_record.count >= trial_app.trial_limit: diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 0d8960c9bd..6f93ff1e70 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -212,13 +212,13 @@ class AccountInitApi(Resource): raise ValueError("invitation_code is required") # check invitation code - invitation_code = ( - db.session.query(InvitationCode) + invitation_code = db.session.scalar( + select(InvitationCode) .where( InvitationCode.code == args.invitation_code, InvitationCode.status == InvitationCodeStatus.UNUSED, ) - .first() + .limit(1) ) if not invitation_code: diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index dd302b90d6..e3bf4c95b8 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") - member = db.session.query(Account).where(Account.id == str(member_id)).first() + member = db.session.get(Account, str(member_id)) if member is None: abort(404) else: diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index e65df87b2b..88fd2c010f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -220,7 +220,7 @@ class SwitchWorkspaceApi(Resource): except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant + new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant if new_tenant is None: raise ValueError("Tenant not found") diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index 4414f1eb5f..c8f674f515 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -24,13 +24,8 @@ class TestBannerApi: banner.status = BannerStatus.ENABLED banner.created_at = datetime(2024, 1, 1) - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.return_value = [banner] - session = MagicMock() - session.query.return_value = query + session.scalars.return_value.all.return_value = [banner] with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session): result = method(api) @@ -58,16 +53,14 @@ class TestBannerApi: banner.status = BannerStatus.ENABLED banner.created_at = None - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.side_effect = [ + scalars_result = MagicMock() + scalars_result.all.side_effect = [ [], [banner], ] session = MagicMock() - session.query.return_value = query + session.scalars.return_value = scalars_result with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session): result = method(api) @@ -87,13 +80,8 @@ class TestBannerApi: api = banner_module.BannerApi() method = unwrap(api.get) - query = MagicMock() - query.where.return_value = query - query.order_by.return_value = query - query.all.return_value = [] - session = MagicMock() - session.query.return_value = query + session.scalars.return_value.all.return_value = [] with app.test_request_context("/"), patch.object(banner_module.db, "session", session): result = method(api) diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py index 3983a6a97e..93652e75d2 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi: app_entity.tenant_id = "t2" session = MagicMock() - session.query.return_value.where.return_value.first.side_effect = [ - recommended, - app_entity, - None, - ] + # scalar() is called for recommended_app and installed_app lookups + session.scalar.side_effect = [recommended, None] + # get() is called for app PK lookup + session.get.return_value = app_entity with ( app.test_request_context("/", json={"app_id": "a1"}), @@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi: method = unwrap(api.post) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with ( app.test_request_context("/", json={"app_id": "a1"}), @@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi: app_entity = MagicMock(is_public=False) session = MagicMock() - session.query.return_value.where.return_value.first.side_effect = [ - recommended, - app_entity, - ] + # scalar() returns recommended_app + session.scalar.return_value = recommended + # get() returns the app entity + session.get.return_value = app_entity with ( app.test_request_context("/", json={"app_id": "a1"}), diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index d85114c8fb..5a03daecbc 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -958,8 +958,8 @@ class TestTrialSitApi: app_model = MagicMock() app_model.id = "a1" - with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = None + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = None with pytest.raises(Forbidden): method(api, app_model) @@ -973,8 +973,8 @@ class TestTrialSitApi: app_model.tenant = MagicMock() app_model.tenant.status = TenantStatus.ARCHIVE - with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = site + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = site with pytest.raises(Forbidden): method(api, app_model) @@ -990,10 +990,10 @@ class TestTrialSitApi: with ( app.test_request_context("/"), - patch.object(module.db.session, "query") as mock_query, + patch.object(module.db.session, "scalar") as mock_scalar, patch.object(module.SiteResponse, "model_validate") as mock_validate, ): - mock_query.return_value.where.return_value.first.return_value = site + mock_scalar.return_value = site mock_validate_result = MagicMock() mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"} mock_validate.return_value = mock_validate_result diff --git a/api/tests/unit_tests/controllers/console/explore/test_wraps.py b/api/tests/unit_tests/controllers/console/explore/test_wraps.py index 67e7a32591..2c1acfc3d6 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/explore/test_wraps.py @@ -34,9 +34,9 @@ def test_installed_app_required_not_found(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = None + scalar_mock.return_value = None with pytest.raises(NotFound): view("app-id") @@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, patch("controllers.console.explore.wraps.db.session.delete"), patch("controllers.console.explore.wraps.db.session.commit"), ): - q.return_value.where.return_value.first.return_value = installed_app + scalar_mock.return_value = installed_app with pytest.raises(NotFound): view("app-id") @@ -76,9 +76,9 @@ def test_installed_app_required_success(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = installed_app + scalar_mock.return_value = installed_app result = view("app-id") assert result == installed_app @@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.return_value = None + scalar_mock.return_value = None with pytest.raises(TrialAppNotAllowed): view("app-id") @@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.side_effect = [ + scalar_mock.side_effect = [ trial_app, record, ] @@ -194,9 +194,9 @@ def test_trial_app_required_success(): "controllers.console.explore.wraps.current_account_with_tenant", return_value=(MagicMock(id="user-1"), None), ), - patch("controllers.console.explore.wraps.db.session.query") as q, + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, ): - q.return_value.where.return_value.first.side_effect = [ + scalar_mock.side_effect = [ trial_app, record, ] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index 00d322fdea..42be02cdaf 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -55,9 +55,9 @@ class TestAccountInitApi: patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), patch("controllers.console.workspace.account.db.session.commit", return_value=None), patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), - patch("controllers.console.workspace.account.db.session.query") as query_mock, + patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock, ): - query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused") + scalar_mock.return_value = MagicMock(status="unused") resp = method(api) assert resp["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index b6708d1f6f..718b57ba6b 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -207,10 +207,10 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 200 @@ -226,9 +226,9 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, ): - q.return_value.where.return_value.first.return_value = None + get_mock.return_value = None with pytest.raises(HTTPException): method(api, "x") @@ -244,13 +244,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.CannotOperateSelfError("x"), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 400 @@ -266,13 +266,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.NoPermissionError("x"), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 403 @@ -288,13 +288,13 @@ class TestMemberCancelInviteApi: with ( app.test_request_context("/"), patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), - patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.db.session.get") as get_mock, patch( "controllers.console.workspace.members.TenantService.remove_member_from_tenant", side_effect=services.errors.account.MemberNotInTenantError(), ), ): - q.return_value.where.return_value.first.return_value = member + get_mock.return_value = member result, status = method(api, member.id) assert status == 404 diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index aa881f0521..f5ebe0b534 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -449,12 +449,12 @@ class TestSwitchWorkspaceApi: "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), - patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"} ), ): - query_mock.return_value.get.return_value = tenant + get_mock.return_value = tenant result = method(api) assert result["result"] == "success" @@ -488,9 +488,9 @@ class TestSwitchWorkspaceApi: return_value=(MagicMock(), "t1"), ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), - patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, ): - query_mock.return_value.get.return_value = None + get_mock.return_value = None with pytest.raises(ValueError): method(api)