mirror of https://github.com/langgenius/dify.git
feat: app trial (#26281)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: hj24 <mambahj24@gmail.com>
This commit is contained in:
parent
136618b567
commit
515002a8ba
|
|
@ -965,6 +965,16 @@ class MailConfig(BaseSettings):
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ENABLE_TRIAL_APP: bool = Field(
|
||||||
|
description="Enable trial app",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||||
|
description="Enable explore banner",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RagEtlConfig(BaseSettings):
|
class RagEtlConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -107,10 +107,12 @@ from .datasets.rag_pipeline import (
|
||||||
|
|
||||||
# Import explore controllers
|
# Import explore controllers
|
||||||
from .explore import (
|
from .explore import (
|
||||||
|
banner,
|
||||||
installed_app,
|
installed_app,
|
||||||
parameter,
|
parameter,
|
||||||
recommended_app,
|
recommended_app,
|
||||||
saved_message,
|
saved_message,
|
||||||
|
trial,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import tag controllers
|
# Import tag controllers
|
||||||
|
|
@ -145,6 +147,7 @@ __all__ = [
|
||||||
"apikey",
|
"apikey",
|
||||||
"app",
|
"app",
|
||||||
"audio",
|
"audio",
|
||||||
|
"banner",
|
||||||
"billing",
|
"billing",
|
||||||
"bp",
|
"bp",
|
||||||
"completion",
|
"completion",
|
||||||
|
|
@ -198,6 +201,7 @@ __all__ = [
|
||||||
"statistic",
|
"statistic",
|
||||||
"tags",
|
"tags",
|
||||||
"tool_providers",
|
"tool_providers",
|
||||||
|
"trial",
|
||||||
"trigger_providers",
|
"trigger_providers",
|
||||||
"version",
|
"version",
|
||||||
"website",
|
"website",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
|
||||||
from core.db.session_factory import session_factory
|
from core.db.session_factory import session_factory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
|
||||||
language: str = Field(...)
|
language: str = Field(...)
|
||||||
category: str = Field(...)
|
category: str = Field(...)
|
||||||
position: int = Field(...)
|
position: int = Field(...)
|
||||||
|
can_trial: bool = Field(default=False)
|
||||||
|
trial_limit: int = Field(default=0)
|
||||||
|
|
||||||
@field_validator("language")
|
@field_validator("language")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
|
||||||
return supported_language(value)
|
return supported_language(value)
|
||||||
|
|
||||||
|
|
||||||
|
class InsertExploreBannerPayload(BaseModel):
|
||||||
|
category: str = Field(...)
|
||||||
|
title: str = Field(...)
|
||||||
|
description: str = Field(...)
|
||||||
|
img_src: str = Field(..., alias="img-src")
|
||||||
|
language: str = Field(default="en-US")
|
||||||
|
link: str = Field(...)
|
||||||
|
sort: int = Field(...)
|
||||||
|
|
||||||
|
@field_validator("language")
|
||||||
|
@classmethod
|
||||||
|
def validate_language(cls, value: str) -> str:
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
model_config = {"populate_by_name": True}
|
||||||
|
|
||||||
|
|
||||||
console_ns.schema_model(
|
console_ns.schema_model(
|
||||||
InsertExploreAppPayload.__name__,
|
InsertExploreAppPayload.__name__,
|
||||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
InsertExploreBannerPayload.__name__,
|
||||||
|
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view: Callable[P, R]):
|
def admin_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
|
|
@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(recommended_app)
|
db.session.add(recommended_app)
|
||||||
|
if payload.can_trial:
|
||||||
|
trial_app = db.session.execute(
|
||||||
|
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if not trial_app:
|
||||||
|
db.session.add(
|
||||||
|
TrialApp(
|
||||||
|
app_id=payload.app_id,
|
||||||
|
tenant_id=app.tenant_id,
|
||||||
|
trial_limit=payload.trial_limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
trial_app.trial_limit = payload.trial_limit
|
||||||
|
|
||||||
app.is_public = True
|
app.is_public = True
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
|
||||||
recommended_app.category = payload.category
|
recommended_app.category = payload.category
|
||||||
recommended_app.position = payload.position
|
recommended_app.position = payload.position
|
||||||
|
|
||||||
|
if payload.can_trial:
|
||||||
|
trial_app = db.session.execute(
|
||||||
|
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if not trial_app:
|
||||||
|
db.session.add(
|
||||||
|
TrialApp(
|
||||||
|
app_id=payload.app_id,
|
||||||
|
tenant_id=app.tenant_id,
|
||||||
|
trial_limit=payload.trial_limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
trial_app.trial_limit = payload.trial_limit
|
||||||
app.is_public = True
|
app.is_public = True
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
|
||||||
for installed_app in installed_apps:
|
for installed_app in installed_apps:
|
||||||
session.delete(installed_app)
|
session.delete(installed_app)
|
||||||
|
|
||||||
|
trial_app = session.execute(
|
||||||
|
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if trial_app:
|
||||||
|
session.delete(trial_app)
|
||||||
|
|
||||||
db.session.delete(recommended_app)
|
db.session.delete(recommended_app)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/admin/insert-explore-banner")
|
||||||
|
class InsertExploreBannerApi(Resource):
|
||||||
|
@console_ns.doc("insert_explore_banner")
|
||||||
|
@console_ns.doc(description="Insert an explore banner")
|
||||||
|
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||||
|
@console_ns.response(201, "Banner inserted successfully")
|
||||||
|
@only_edition_cloud
|
||||||
|
@admin_required
|
||||||
|
def post(self):
|
||||||
|
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
|
content = {
|
||||||
|
"category": payload.category,
|
||||||
|
"title": payload.title,
|
||||||
|
"description": payload.description,
|
||||||
|
"img-src": payload.img_src,
|
||||||
|
}
|
||||||
|
|
||||||
|
banner = ExporleBanner(
|
||||||
|
content=content,
|
||||||
|
link=payload.link,
|
||||||
|
sort=payload.sort,
|
||||||
|
language=payload.language,
|
||||||
|
)
|
||||||
|
db.session.add(banner)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||||
|
class DeleteExploreBannerApi(Resource):
|
||||||
|
@console_ns.doc("delete_explore_banner")
|
||||||
|
@console_ns.doc(description="Delete an explore banner")
|
||||||
|
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||||
|
@console_ns.response(204, "Banner deleted successfully")
|
||||||
|
@only_edition_cloud
|
||||||
|
@admin_required
|
||||||
|
def delete(self, banner_id):
|
||||||
|
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||||
|
if not banner:
|
||||||
|
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||||
|
|
||||||
|
db.session.delete(banner)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {"result": "success"}, 204
|
||||||
|
|
|
||||||
|
|
@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||||
error_code = "rate_limit_error"
|
error_code = "rate_limit_error"
|
||||||
description = "Rate Limit Error"
|
description = "Rate Limit Error"
|
||||||
code = 429
|
code = 429
|
||||||
|
|
||||||
|
|
||||||
|
class NeedAddIdsError(BaseHTTPException):
|
||||||
|
error_code = "need_add_ids"
|
||||||
|
description = "Need to add ids."
|
||||||
|
code = 400
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
|
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||||
|
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||||
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||||
def decorator(view_func: Callable[P1, R1]):
|
def decorator(view_func: Callable[P1, R1]):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
|
|
@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||||
return decorator
|
return decorator
|
||||||
else:
|
else:
|
||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||||
|
def decorator(view_func: Callable[P, R]):
|
||||||
|
@wraps(view_func)
|
||||||
|
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
if not kwargs.get("app_id"):
|
||||||
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
app_id = kwargs.get("app_id")
|
||||||
|
app_id = str(app_id)
|
||||||
|
|
||||||
|
del kwargs["app_id"]
|
||||||
|
|
||||||
|
app_model = _load_app_model_with_trial(app_id)
|
||||||
|
|
||||||
|
if not app_model:
|
||||||
|
raise AppNotFoundError()
|
||||||
|
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
|
||||||
|
if mode is not None:
|
||||||
|
if isinstance(mode, list):
|
||||||
|
modes = mode
|
||||||
|
else:
|
||||||
|
modes = [mode]
|
||||||
|
|
||||||
|
if app_mode not in modes:
|
||||||
|
mode_values = {m.value for m in modes}
|
||||||
|
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||||
|
|
||||||
|
kwargs["app_model"] = app_model
|
||||||
|
|
||||||
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_view
|
||||||
|
|
||||||
|
if view is None:
|
||||||
|
return decorator
|
||||||
|
else:
|
||||||
|
return decorator(view)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.explore.wraps import explore_banner_enabled
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import ExporleBanner
|
||||||
|
|
||||||
|
|
||||||
|
class BannerApi(Resource):
|
||||||
|
"""Resource for banner list."""
|
||||||
|
|
||||||
|
@explore_banner_enabled
|
||||||
|
def get(self):
|
||||||
|
"""Get banner list."""
|
||||||
|
language = request.args.get("language", "en-US")
|
||||||
|
|
||||||
|
# Build base query for enabled banners
|
||||||
|
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||||
|
|
||||||
|
# Try to get banners in the requested language
|
||||||
|
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||||
|
|
||||||
|
# Fallback to en-US if no banners found and language is not en-US
|
||||||
|
if not banners and language != "en-US":
|
||||||
|
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||||
|
# Convert banners to serializable format
|
||||||
|
result = []
|
||||||
|
for banner in banners:
|
||||||
|
banner_data = {
|
||||||
|
"id": banner.id,
|
||||||
|
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||||
|
"link": banner.link,
|
||||||
|
"sort": banner.sort,
|
||||||
|
"status": banner.status,
|
||||||
|
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||||
|
}
|
||||||
|
result.append(banner_data)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(BannerApi, "/explore/banners")
|
||||||
|
|
@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
||||||
error_code = "access_denied"
|
error_code = "access_denied"
|
||||||
description = "App access denied."
|
description = "App access denied."
|
||||||
code = 403
|
code = 403
|
||||||
|
|
||||||
|
|
||||||
|
class TrialAppNotAllowed(BaseHTTPException):
|
||||||
|
"""*403* `Trial App Not Allowed`
|
||||||
|
|
||||||
|
Raise if the user has reached the trial app limit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
error_code = "trial_app_not_allowed"
|
||||||
|
code = 403
|
||||||
|
description = "the app is not allowed to be trial."
|
||||||
|
|
||||||
|
|
||||||
|
class TrialAppLimitExceeded(BaseHTTPException):
|
||||||
|
"""*403* `Trial App Limit Exceeded`
|
||||||
|
|
||||||
|
Raise if the user has exceeded the trial app limit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
error_code = "trial_app_limit_exceeded"
|
||||||
|
code = 403
|
||||||
|
description = "The user has exceeded the trial app limit."
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ recommended_app_fields = {
|
||||||
"category": fields.String,
|
"category": fields.String,
|
||||||
"position": fields.Integer,
|
"position": fields.Integer,
|
||||||
"is_listed": fields.Boolean,
|
"is_listed": fields.Boolean,
|
||||||
|
"can_trial": fields.Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
recommended_app_list_fields = {
|
recommended_app_list_fields = {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,512 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||||
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.common.fields import Parameters as ParametersResponse
|
||||||
|
from controllers.common.fields import Site as SiteResponse
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
AppUnavailableError,
|
||||||
|
AudioTooLargeError,
|
||||||
|
CompletionRequestError,
|
||||||
|
ConversationCompletedError,
|
||||||
|
NeedAddIdsError,
|
||||||
|
NoAudioUploadedError,
|
||||||
|
ProviderModelCurrentlyNotSupportError,
|
||||||
|
ProviderNotInitializeError,
|
||||||
|
ProviderNotSupportSpeechToTextError,
|
||||||
|
ProviderQuotaExceededError,
|
||||||
|
UnsupportedAudioTypeError,
|
||||||
|
)
|
||||||
|
from controllers.console.app.wraps import get_app_model_with_trial
|
||||||
|
from controllers.console.explore.error import (
|
||||||
|
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||||
|
NotChatAppError,
|
||||||
|
NotCompletionAppError,
|
||||||
|
NotWorkflowAppError,
|
||||||
|
)
|
||||||
|
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||||
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.errors.error import (
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
|
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.app_fields import app_detail_fields_with_site
|
||||||
|
from fields.dataset_fields import dataset_fields
|
||||||
|
from fields.workflow_fields import workflow_fields
|
||||||
|
from libs import helper
|
||||||
|
from libs.helper import uuid_value
|
||||||
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
|
from models.account import TenantStatus
|
||||||
|
from models.model import AppMode, Site
|
||||||
|
from models.workflow import Workflow
|
||||||
|
from services.app_generate_service import AppGenerateService
|
||||||
|
from services.app_service import AppService
|
||||||
|
from services.audio_service import AudioService
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.errors.audio import (
|
||||||
|
AudioTooLargeServiceError,
|
||||||
|
NoAudioUploadedServiceError,
|
||||||
|
ProviderNotSupportSpeechToTextServiceError,
|
||||||
|
UnsupportedAudioTypeServiceError,
|
||||||
|
)
|
||||||
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
from services.errors.message import (
|
||||||
|
MessageNotExistsError,
|
||||||
|
SuggestedQuestionsAfterAnswerDisabledError,
|
||||||
|
)
|
||||||
|
from services.message_service import MessageService
|
||||||
|
from services.recommended_app_service import RecommendedAppService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||||
|
def post(self, trial_app):
|
||||||
|
"""
|
||||||
|
Run workflow
|
||||||
|
"""
|
||||||
|
app_model = trial_app
|
||||||
|
if not app_model:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
if app_mode != AppMode.WORKFLOW:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
assert current_user is not None
|
||||||
|
try:
|
||||||
|
app_id = app_model.id
|
||||||
|
user_id = current_user.id
|
||||||
|
response = AppGenerateService.generate(
|
||||||
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
|
)
|
||||||
|
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||||
|
return helper.compact_generate_response(response)
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except InvokeRateLimitError as ex:
|
||||||
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||||
|
def post(self, trial_app, task_id: str):
|
||||||
|
"""
|
||||||
|
Stop workflow task
|
||||||
|
"""
|
||||||
|
app_model = trial_app
|
||||||
|
if not app_model:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
if app_mode != AppMode.WORKFLOW:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
|
assert current_user is not None
|
||||||
|
|
||||||
|
# Stop using both mechanisms for backward compatibility
|
||||||
|
# Legacy stop flag mechanism (without user check)
|
||||||
|
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||||
|
|
||||||
|
# New graph engine command channel mechanism
|
||||||
|
GraphEngineManager.send_stop_command(task_id)
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class TrialChatApi(TrialAppResource):
|
||||||
|
@trial_feature_enable
|
||||||
|
def post(self, trial_app):
|
||||||
|
app_model = trial_app
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
|
raise NotChatAppError()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
|
parser.add_argument("query", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
|
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
|
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
|
||||||
|
# Get IDs before they might be detached from session
|
||||||
|
app_id = app_model.id
|
||||||
|
user_id = current_user.id
|
||||||
|
|
||||||
|
response = AppGenerateService.generate(
|
||||||
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
|
)
|
||||||
|
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||||
|
return helper.compact_generate_response(response)
|
||||||
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
except services.errors.conversation.ConversationCompletedError:
|
||||||
|
raise ConversationCompletedError()
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logger.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except InvokeRateLimitError as ex:
|
||||||
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||||
|
@trial_feature_enable
|
||||||
|
def get(self, trial_app, message_id):
|
||||||
|
app_model = trial_app
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
|
raise NotChatAppError()
|
||||||
|
|
||||||
|
message_id = str(message_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
|
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||||
|
)
|
||||||
|
except MessageNotExistsError:
|
||||||
|
raise NotFound("Message not found")
|
||||||
|
except ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation not found")
|
||||||
|
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||||
|
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
return {"data": questions}
|
||||||
|
|
||||||
|
|
||||||
|
class TrialChatAudioApi(TrialAppResource):
|
||||||
|
@trial_feature_enable
|
||||||
|
def post(self, trial_app):
|
||||||
|
app_model = trial_app
|
||||||
|
|
||||||
|
file = request.files["file"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
|
||||||
|
# Get IDs before they might be detached from session
|
||||||
|
app_id = app_model.id
|
||||||
|
user_id = current_user.id
|
||||||
|
|
||||||
|
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||||
|
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||||
|
return response
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logger.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except NoAudioUploadedServiceError:
|
||||||
|
raise NoAudioUploadedError()
|
||||||
|
except AudioTooLargeServiceError as e:
|
||||||
|
raise AudioTooLargeError(str(e))
|
||||||
|
except UnsupportedAudioTypeServiceError:
|
||||||
|
raise UnsupportedAudioTypeError()
|
||||||
|
except ProviderNotSupportSpeechToTextServiceError:
|
||||||
|
raise ProviderNotSupportSpeechToTextError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
class TrialChatTextApi(TrialAppResource):
|
||||||
|
@trial_feature_enable
|
||||||
|
def post(self, trial_app):
|
||||||
|
app_model = trial_app
|
||||||
|
try:
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||||
|
parser.add_argument("voice", type=str, location="json")
|
||||||
|
parser.add_argument("text", type=str, location="json")
|
||||||
|
parser.add_argument("streaming", type=bool, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
message_id = args.get("message_id", None)
|
||||||
|
text = args.get("text", None)
|
||||||
|
voice = args.get("voice", None)
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
|
||||||
|
# Get IDs before they might be detached from session
|
||||||
|
app_id = app_model.id
|
||||||
|
user_id = current_user.id
|
||||||
|
|
||||||
|
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||||
|
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||||
|
return response
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logger.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except NoAudioUploadedServiceError:
|
||||||
|
raise NoAudioUploadedError()
|
||||||
|
except AudioTooLargeServiceError as e:
|
||||||
|
raise AudioTooLargeError(str(e))
|
||||||
|
except UnsupportedAudioTypeServiceError:
|
||||||
|
raise UnsupportedAudioTypeError()
|
||||||
|
except ProviderNotSupportSpeechToTextServiceError:
|
||||||
|
raise ProviderNotSupportSpeechToTextError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
class TrialCompletionApi(TrialAppResource):
|
||||||
|
@trial_feature_enable
|
||||||
|
def post(self, trial_app):
|
||||||
|
app_model = trial_app
|
||||||
|
if app_model.mode != "completion":
|
||||||
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
|
parser.add_argument("query", type=str, location="json", default="")
|
||||||
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
|
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
|
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
args["auto_generate_name"] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise ValueError("current_user must be an Account instance")
|
||||||
|
|
||||||
|
# Get IDs before they might be detached from session
|
||||||
|
app_id = app_model.id
|
||||||
|
user_id = current_user.id
|
||||||
|
|
||||||
|
response = AppGenerateService.generate(
|
||||||
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||||
|
return helper.compact_generate_response(response)
|
||||||
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
except services.errors.conversation.ConversationCompletedError:
|
||||||
|
raise ConversationCompletedError()
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logger.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
class TrialSitApi(Resource):
|
||||||
|
"""Resource for trial app sites."""
|
||||||
|
|
||||||
|
@trial_feature_enable
|
||||||
|
@get_app_model_with_trial
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Retrieve app site info.
|
||||||
|
|
||||||
|
Returns the site configuration for the application including theme, icons, and text.
|
||||||
|
"""
|
||||||
|
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||||
|
|
||||||
|
if not site:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
assert app_model.tenant
|
||||||
|
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
class TrialAppParameterApi(Resource):
|
||||||
|
"""Resource for app variables."""
|
||||||
|
|
||||||
|
@trial_feature_enable
|
||||||
|
@get_app_model_with_trial
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Retrieve app parameters."""
|
||||||
|
|
||||||
|
if app_model is None:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
|
||||||
|
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
|
workflow = app_model.workflow
|
||||||
|
if workflow is None:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
|
||||||
|
features_dict = workflow.features_dict
|
||||||
|
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||||
|
else:
|
||||||
|
app_model_config = app_model.app_model_config
|
||||||
|
if app_model_config is None:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
|
||||||
|
features_dict = app_model_config.to_dict()
|
||||||
|
|
||||||
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
|
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
|
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
class AppApi(Resource):
|
||||||
|
@trial_feature_enable
|
||||||
|
@get_app_model_with_trial
|
||||||
|
@marshal_with(app_detail_fields_with_site)
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Get app detail"""
|
||||||
|
|
||||||
|
app_service = AppService()
|
||||||
|
app_model = app_service.get_app(app_model)
|
||||||
|
|
||||||
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
|
class AppWorkflowApi(Resource):
|
||||||
|
@trial_feature_enable
|
||||||
|
@get_app_model_with_trial
|
||||||
|
@marshal_with(workflow_fields)
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Get workflow detail"""
|
||||||
|
if not app_model.workflow_id:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
|
||||||
|
workflow = (
|
||||||
|
db.session.query(Workflow)
|
||||||
|
.where(
|
||||||
|
Workflow.id == app_model.workflow_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetListApi(Resource):
|
||||||
|
@trial_feature_enable
|
||||||
|
@get_app_model_with_trial
|
||||||
|
def get(self, app_model):
|
||||||
|
page = request.args.get("page", default=1, type=int)
|
||||||
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
|
ids = request.args.getlist("ids")
|
||||||
|
|
||||||
|
tenant_id = app_model.tenant_id
|
||||||
|
if ids:
|
||||||
|
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
|
||||||
|
else:
|
||||||
|
raise NeedAddIdsError()
|
||||||
|
|
||||||
|
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
|
||||||
|
|
||||||
|
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
TrialMessageSuggestedQuestionApi,
|
||||||
|
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||||
|
endpoint="trial_app_suggested_question",
|
||||||
|
)
|
||||||
|
|
||||||
|
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||||
|
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||||
|
|
||||||
|
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||||
|
|
||||||
|
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||||
|
|
||||||
|
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||||
|
|
||||||
|
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||||
|
|
||||||
|
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||||
|
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||||
|
|
||||||
|
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||||
|
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||||
|
|
@ -2,14 +2,15 @@ from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from flask import abort
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console.explore.error import AppAccessDeniedError
|
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import InstalledApp
|
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||||
|
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||||
|
|
||||||
|
if trial_app is None:
|
||||||
|
raise TrialAppNotAllowed()
|
||||||
|
app = trial_app.app
|
||||||
|
|
||||||
|
if app is None:
|
||||||
|
raise TrialAppNotAllowed()
|
||||||
|
|
||||||
|
account_trial_app_record = (
|
||||||
|
db.session.query(AccountTrialAppRecord)
|
||||||
|
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if account_trial_app_record:
|
||||||
|
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||||
|
raise TrialAppLimitExceeded()
|
||||||
|
|
||||||
|
return view(app, *args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
if view:
|
||||||
|
return decorator(view)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
features = FeatureService.get_system_features()
|
||||||
|
if not features.enable_trial_app:
|
||||||
|
abort(403, "Trial app feature is not enabled.")
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
features = FeatureService.get_system_features()
|
||||||
|
if not features.enable_explore_banner:
|
||||||
|
abort(403, "Explore banner feature is not enabled.")
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
class InstalledAppResource(Resource):
|
class InstalledAppResource(Resource):
|
||||||
# must be reversed if there are multiple decorators
|
# must be reversed if there are multiple decorators
|
||||||
|
|
||||||
|
|
@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
login_required,
|
login_required,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TrialAppResource(Resource):
|
||||||
|
# must be reversed if there are multiple decorators
|
||||||
|
|
||||||
|
method_decorators = [
|
||||||
|
trial_app_required,
|
||||||
|
account_initialization_required,
|
||||||
|
login_required,
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
"""add table explore banner and trial
|
||||||
|
|
||||||
|
Revision ID: f9f6d18a37f9
|
||||||
|
Revises: 9e6fa5cbcd80
|
||||||
|
Create Date: 2026-01-017 11:10:18.079355
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'f9f6d18a37f9'
|
||||||
|
down_revision = '9e6fa5cbcd80'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('account_trial_app_records',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('account_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
|
||||||
|
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
|
||||||
|
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('exporle_banners',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('content', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('link', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('sort', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('trial_apps',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('trial_limit', sa.Integer(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
|
||||||
|
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
|
||||||
|
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('trial_app_tenant_id_idx')
|
||||||
|
batch_op.drop_index('trial_app_app_id_idx')
|
||||||
|
|
||||||
|
op.drop_table('trial_apps')
|
||||||
|
op.drop_table('exporle_banners')
|
||||||
|
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('account_trial_app_record_app_id_idx')
|
||||||
|
batch_op.drop_index('account_trial_app_record_account_id_idx')
|
||||||
|
|
||||||
|
op.drop_table('account_trial_app_records')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
@ -35,6 +35,7 @@ from .enums import (
|
||||||
WorkflowTriggerStatus,
|
WorkflowTriggerStatus,
|
||||||
)
|
)
|
||||||
from .model import (
|
from .model import (
|
||||||
|
AccountTrialAppRecord,
|
||||||
ApiRequest,
|
ApiRequest,
|
||||||
ApiToken,
|
ApiToken,
|
||||||
App,
|
App,
|
||||||
|
|
@ -47,6 +48,7 @@ from .model import (
|
||||||
DatasetRetrieverResource,
|
DatasetRetrieverResource,
|
||||||
DifySetup,
|
DifySetup,
|
||||||
EndUser,
|
EndUser,
|
||||||
|
ExporleBanner,
|
||||||
IconType,
|
IconType,
|
||||||
InstalledApp,
|
InstalledApp,
|
||||||
Message,
|
Message,
|
||||||
|
|
@ -62,6 +64,7 @@ from .model import (
|
||||||
TagBinding,
|
TagBinding,
|
||||||
TenantCreditPool,
|
TenantCreditPool,
|
||||||
TraceAppConfig,
|
TraceAppConfig,
|
||||||
|
TrialApp,
|
||||||
UploadFile,
|
UploadFile,
|
||||||
)
|
)
|
||||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||||
|
|
@ -114,6 +117,7 @@ __all__ = [
|
||||||
"Account",
|
"Account",
|
||||||
"AccountIntegrate",
|
"AccountIntegrate",
|
||||||
"AccountStatus",
|
"AccountStatus",
|
||||||
|
"AccountTrialAppRecord",
|
||||||
"ApiRequest",
|
"ApiRequest",
|
||||||
"ApiToken",
|
"ApiToken",
|
||||||
"ApiToolProvider",
|
"ApiToolProvider",
|
||||||
|
|
@ -150,6 +154,7 @@ __all__ = [
|
||||||
"DocumentSegment",
|
"DocumentSegment",
|
||||||
"Embedding",
|
"Embedding",
|
||||||
"EndUser",
|
"EndUser",
|
||||||
|
"ExporleBanner",
|
||||||
"ExternalKnowledgeApis",
|
"ExternalKnowledgeApis",
|
||||||
"ExternalKnowledgeBindings",
|
"ExternalKnowledgeBindings",
|
||||||
"IconType",
|
"IconType",
|
||||||
|
|
@ -188,6 +193,7 @@ __all__ = [
|
||||||
"ToolLabelBinding",
|
"ToolLabelBinding",
|
||||||
"ToolModelInvoke",
|
"ToolModelInvoke",
|
||||||
"TraceAppConfig",
|
"TraceAppConfig",
|
||||||
|
"TrialApp",
|
||||||
"TriggerOAuthSystemClient",
|
"TriggerOAuthSystemClient",
|
||||||
"TriggerOAuthTenantClient",
|
"TriggerOAuthTenantClient",
|
||||||
"TriggerSubscription",
|
"TriggerSubscription",
|
||||||
|
|
|
||||||
|
|
@ -603,6 +603,64 @@ class InstalledApp(TypeBase):
|
||||||
return tenant
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
|
class TrialApp(Base):
|
||||||
|
__tablename__ = "trial_apps"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
|
||||||
|
sa.Index("trial_app_app_id_idx", "app_id"),
|
||||||
|
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
|
||||||
|
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
app_id = mapped_column(StringUUID, nullable=False)
|
||||||
|
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||||
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app(self) -> App | None:
|
||||||
|
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
class AccountTrialAppRecord(Base):
|
||||||
|
__tablename__ = "account_trial_app_records"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
|
||||||
|
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
|
||||||
|
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||||
|
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||||
|
)
|
||||||
|
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
account_id = mapped_column(StringUUID, nullable=False)
|
||||||
|
app_id = mapped_column(StringUUID, nullable=False)
|
||||||
|
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app(self) -> App | None:
|
||||||
|
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||||
|
return app
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self) -> Account | None:
|
||||||
|
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
class ExporleBanner(Base):
|
||||||
|
__tablename__ = "exporle_banners"
|
||||||
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||||
|
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
content = mapped_column(sa.JSON, nullable=False)
|
||||||
|
link = mapped_column(String(255), nullable=False)
|
||||||
|
sort = mapped_column(sa.Integer, nullable=False)
|
||||||
|
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
|
||||||
|
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||||
|
|
||||||
|
|
||||||
class OAuthProviderApp(TypeBase):
|
class OAuthProviderApp(TypeBase):
|
||||||
"""
|
"""
|
||||||
Globally shared OAuth provider app information.
|
Globally shared OAuth provider app information.
|
||||||
|
|
|
||||||
|
|
@ -170,6 +170,8 @@ class SystemFeatureModel(BaseModel):
|
||||||
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
|
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
|
||||||
enable_change_email: bool = True
|
enable_change_email: bool = True
|
||||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||||
|
enable_trial_app: bool = False
|
||||||
|
enable_explore_banner: bool = False
|
||||||
|
|
||||||
|
|
||||||
class FeatureService:
|
class FeatureService:
|
||||||
|
|
@ -225,6 +227,8 @@ class FeatureService:
|
||||||
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
||||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||||
|
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
|
||||||
|
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _fulfill_params_from_env(cls, features: FeatureModel):
|
def _fulfill_params_from_env(cls, features: FeatureModel):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import AccountTrialAppRecord, TrialApp
|
||||||
|
from services.feature_service import FeatureService
|
||||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,6 +23,15 @@ class RecommendedAppService:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if FeatureService.get_system_features().enable_trial_app:
|
||||||
|
apps = result["recommended_apps"]
|
||||||
|
for app in apps:
|
||||||
|
app_id = app["app_id"]
|
||||||
|
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||||
|
if trial_app_model:
|
||||||
|
app["can_trial"] = True
|
||||||
|
else:
|
||||||
|
app["can_trial"] = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -32,4 +44,30 @@ class RecommendedAppService:
|
||||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||||
|
if FeatureService.get_system_features().enable_trial_app:
|
||||||
|
app_id = result["id"]
|
||||||
|
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||||
|
if trial_app_model:
|
||||||
|
result["can_trial"] = True
|
||||||
|
else:
|
||||||
|
result["can_trial"] = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_trial_app_record(cls, app_id: str, account_id: str):
|
||||||
|
"""
|
||||||
|
Add trial app record.
|
||||||
|
:param app_id: app id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
account_trial_app_record = (
|
||||||
|
db.session.query(AccountTrialAppRecord)
|
||||||
|
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if account_trial_app_record:
|
||||||
|
account_trial_app_record.count += 1
|
||||||
|
db.session.commit()
|
||||||
|
else:
|
||||||
|
db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
|
||||||
|
db.session.commit()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue