mirror of https://github.com/langgenius/dify.git
refactor: select in remaining console app controllers (#33969)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
75c3ef82d9
commit
dd4f504b39
|
|
@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
|
||||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
subquery = (
|
subquery = (
|
||||||
db.session.query(
|
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
|
||||||
)
|
|
||||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
|
||||||
|
|
||||||
def _get_conversation(app_model, conversation_id):
|
def _get_conversation(app_model, conversation_id):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
conversation = (
|
conversation = db.session.scalar(
|
||||||
db.session.query(Conversation)
|
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
|
||||||
try:
|
try:
|
||||||
# Generate from nothing for a workflow node
|
# Generate from nothing for a workflow node
|
||||||
if (args.current in (code_template, "")) and args.node_id != "":
|
if (args.current in (code_template, "")) and args.node_id != "":
|
||||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
app = db.session.get(App, args.flow_id)
|
||||||
if not app:
|
if not app:
|
||||||
return {"error": f"app {args.flow_id} not found"}, 400
|
return {"error": f"app {args.flow_id} not found"}, 400
|
||||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import json
|
||||||
|
|
||||||
from flask_restx import Resource, marshal_with
|
from flask_restx import Resource, marshal_with
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_server_model)
|
@marshal_with(app_server_model)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||||
return server
|
return server
|
||||||
|
|
||||||
@console_ns.doc("create_app_mcp_server")
|
@console_ns.doc("create_app_mcp_server")
|
||||||
|
|
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
server = db.session.get(AppMCPServer, payload.id)
|
||||||
if not server:
|
if not server:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
|
|
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, server_id):
|
def get(self, server_id):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
server = (
|
server = db.session.scalar(
|
||||||
db.session.query(AppMCPServer)
|
select(AppMCPServer)
|
||||||
.where(AppMCPServer.id == server_id)
|
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
.limit(1)
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
if not server:
|
if not server:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
|
||||||
|
|
@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
|
||||||
|
|
||||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||||
# get original app model config
|
# get original app model config
|
||||||
original_app_model_config = (
|
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
|
||||||
)
|
|
||||||
if original_app_model_config is None:
|
if original_app_model_config is None:
|
||||||
raise ValueError("Original app model config not found")
|
raise ValueError("Original app model config not found")
|
||||||
agent_mode = original_app_model_config.agent_mode_dict
|
agent_mode = original_app_model_config.agent_mode_dict
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from typing import Literal
|
||||||
|
|
||||||
from flask_restx import Resource, marshal_with
|
from flask_restx import Resource, marshal_with
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
|
|
@ -75,7 +76,7 @@ class AppSite(Resource):
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||||
if not site:
|
if not site:
|
||||||
raise NotFound
|
raise NotFound
|
||||||
|
|
||||||
|
|
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||||
@marshal_with(app_site_model)
|
@marshal_with(app_site_model)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||||
|
|
||||||
if not site:
|
if not site:
|
||||||
raise NotFound
|
raise NotFound
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar, Union
|
from typing import ParamSpec, TypeVar, Union
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from controllers.console.app.error import AppNotFoundError
|
from controllers.console.app.error import AppNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
|
|
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
|
||||||
|
|
||||||
def _load_app_model(app_id: str) -> App | None:
|
def _load_app_model(app_id: str) -> App | None:
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
app_model = (
|
app_model = db.session.scalar(
|
||||||
db.session.query(App)
|
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -281,12 +281,10 @@ class TestSiteEndpoints:
|
||||||
method = _unwrap(api.post)
|
method = _unwrap(api.post)
|
||||||
|
|
||||||
site = MagicMock()
|
site = MagicMock()
|
||||||
query = MagicMock()
|
|
||||||
query.where.return_value.first.return_value = site
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
site_module.db,
|
site_module.db,
|
||||||
"session",
|
"session",
|
||||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
site_module,
|
site_module,
|
||||||
|
|
@ -305,12 +303,10 @@ class TestSiteEndpoints:
|
||||||
method = _unwrap(api.post)
|
method = _unwrap(api.post)
|
||||||
|
|
||||||
site = MagicMock()
|
site = MagicMock()
|
||||||
query = MagicMock()
|
|
||||||
query.where.return_value.first.return_value = site
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
site_module.db,
|
site_module.db,
|
||||||
"session",
|
"session",
|
||||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
|
|
|
||||||
|
|
@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
||||||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||||
|
|
||||||
query = MagicMock()
|
|
||||||
query.where.return_value = query
|
|
||||||
query.first.return_value = conversation
|
|
||||||
|
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.query.return_value = query
|
session.scalar.return_value = conversation
|
||||||
|
|
||||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||||
|
|
@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
|
||||||
|
|
||||||
|
|
||||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
query = MagicMock()
|
|
||||||
query.where.return_value = query
|
|
||||||
query.first.return_value = None
|
|
||||||
|
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.query.return_value = query
|
session.scalar.return_value = None
|
||||||
|
|
||||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
|
||||||
),
|
),
|
||||||
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
|
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
|
||||||
):
|
):
|
||||||
mock_session.query.return_value.where.return_value.first.return_value = conversation
|
mock_session.scalar.return_value = conversation
|
||||||
|
|
||||||
_get_conversation(app_model, "conversation-id")
|
_get_conversation(app_model, "conversation-id")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
||||||
|
|
||||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||||
|
|
||||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
|
||||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
|
||||||
|
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
"/console/api/instruction-generate",
|
"/console/api/instruction-generate",
|
||||||
|
|
@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||||
|
|
||||||
app_model = SimpleNamespace(id="app-1")
|
app_model = SimpleNamespace(id="app-1")
|
||||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
|
||||||
_install_workflow_service(monkeypatch, workflow=None)
|
_install_workflow_service(monkeypatch, workflow=None)
|
||||||
|
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
|
|
@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||||
|
|
||||||
app_model = SimpleNamespace(id="app-1")
|
app_model = SimpleNamespace(id="app-1")
|
||||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
|
||||||
|
|
||||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||||
|
|
@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||||
|
|
||||||
app_model = SimpleNamespace(id="app-1")
|
app_model = SimpleNamespace(id="app-1")
|
||||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
|
||||||
|
|
||||||
workflow = SimpleNamespace(
|
workflow = SimpleNamespace(
|
||||||
graph_dict={
|
graph_dict={
|
||||||
|
|
|
||||||
|
|
@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
|
||||||
)
|
)
|
||||||
|
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
query = MagicMock()
|
session.get.return_value = original_config
|
||||||
query.where.return_value = query
|
|
||||||
query.first.return_value = original_config
|
|
||||||
session.query.return_value = query
|
|
||||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,8 @@ from models.model import AppMode
|
||||||
|
|
||||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
|
||||||
|
|
||||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||||
|
|
||||||
@wraps_module.get_app_model
|
@wraps_module.get_app_model
|
||||||
def handler(app_model):
|
def handler(app_model):
|
||||||
|
|
@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
|
||||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
|
||||||
|
|
||||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||||
|
|
||||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||||
def handler(app_model):
|
def handler(app_model):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue