dify/api/controllers/service_api/wraps.py

369 lines
15 KiB
Python
Raw Normal View History

import logging
2025-03-10 11:50:11 +00:00
import time
from collections.abc import Callable
from enum import StrEnum, auto
2023-05-15 00:51:32 +00:00
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
2023-05-15 00:51:32 +00:00
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
2025-03-10 11:50:11 +00:00
from extensions.ext_redis import redis_client
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
logger = logging.getLogger(__name__)
2023-05-15 00:51:32 +00:00
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
QUERY = auto()
JSON = auto()
FORM = auto()
class FetchUserArg(BaseModel):
    """Options telling a token-validating decorator how to obtain ``user``."""

    # Which part of the request (query string, JSON body, form) holds ``user``.
    fetch_from: WhereisUserArg
    # When True, a missing/empty ``user`` value is rejected by the decorator.
    required: bool = False
@overload
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
@overload
def validate_app_token(
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token(
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
api_token = validate_and_get_api_token("app")
2023-05-15 00:51:32 +00:00
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
2023-05-15 00:51:32 +00:00
if not app_model:
raise Forbidden("The app no longer exists.")
2023-05-15 00:51:32 +00:00
if app_model.status != "normal":
raise Forbidden("The app's status is abnormal.")
2023-05-15 00:51:32 +00:00
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")
2023-05-15 00:51:32 +00:00
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
2024-12-24 10:38:51 +00:00
if tenant is None:
raise ValueError("Tenant does not exist.")
2024-04-18 09:33:32 +00:00
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")
2024-04-18 09:33:32 +00:00
kwargs["app_model"] = app_model
2023-05-15 00:51:32 +00:00
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
2026-02-02 09:12:03 +00:00
user_id = None
match fetch_user_arg.fetch_from:
case WhereisUserArg.QUERY:
user_id = request.args.get("user")
case WhereisUserArg.JSON:
user_id = request.get_json().get("user")
case WhereisUserArg.FORM:
user_id = request.form.get("user")
2023-05-15 00:51:32 +00:00
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
if user_id:
user_id = str(user_id)
end_user = EndUserService.get_or_create_end_user(app_model, user_id)
kwargs["end_user"] = end_user
# Set EndUser as current logged-in user for flask_login.current_user
2026-03-22 13:21:01 +00:00
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore[attr-defined]
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore[attr-defined]
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
2026-03-22 13:21:01 +00:00
current_app.login_manager._update_request_context_with_user(account) # type: ignore[attr-defined]
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore[attr-defined]
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
2023-05-15 00:51:32 +00:00
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
vector_space = features.vector_space
2024-03-03 04:45:06 +00:00
documents_upload_quota = features.documents_upload_quota
if resource == "members" and 0 < members.limit <= members.size:
raise Forbidden("The number of members has reached the limit of your subscription.")
elif resource == "apps" and 0 < apps.limit <= apps.size:
raise Forbidden("The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Forbidden("The number of documents has reached the limit of your subscription.")
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
    """Build a decorator gating knowledge-base operations by subscription plan.

    With billing enabled, the ``"add_segment"`` resource is refused for
    tenants on the sandbox plan; every other combination reaches the view.

    Args:
        resource: Operation being guarded; only ``"add_segment"`` is restricted.
        api_token_type: Scope passed to ``validate_and_get_api_token``.

    Raises (from the wrapped view):
        Forbidden: billing is enabled, the resource is ``"add_segment"`` and
            the tenant's plan is ``CloudPlan.SANDBOX``.
    """

    def interceptor(view: Callable[P, R]):
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
            token = validate_and_get_api_token(api_token_type)
            tenant_features = FeatureService.get_features(token.tenant_id)

            # Short-circuit order mirrors the nested checks: the plan is only
            # inspected when billing is on and the resource is restricted.
            is_blocked = (
                tenant_features.billing.enabled
                and resource == "add_segment"
                and tenant_features.billing.subscription.plan == CloudPlan.SANDBOX
            )
            if is_blocked:
                raise Forbidden(
                    "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
                )

            return view(*args, **kwargs)

        return decorated

    return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
    """Build a decorator enforcing the per-tenant knowledge-base rate limit.

    Only ``resource == "knowledge"`` is limited, and only when the tenant's
    knowledge rate limit is enabled. Requests are tracked in a Redis sorted
    set keyed by tenant, scored by millisecond timestamp, and trimmed to the
    last 60 000 ms on every call.

    Args:
        resource: Resource being accessed; only ``"knowledge"`` is limited.
        api_token_type: Scope passed to ``validate_and_get_api_token``.

    Raises (from the wrapped view):
        Forbidden: the number of tracked requests exceeds the limit. A
            ``RateLimitLog`` row is committed before raising.
    """

    def interceptor(view: Callable[P, R]):
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
            token = validate_and_get_api_token(api_token_type)

            if resource != "knowledge":
                return view(*args, **kwargs)

            rate_limit = FeatureService.get_knowledge_rate_limit(token.tenant_id)
            if not rate_limit.enabled:
                return view(*args, **kwargs)

            now_ms = int(time.time() * 1000)
            window_key = f"rate_limit_{token.tenant_id}"

            # Record this request first, then drop entries older than the
            # window, then count what is left (the current request included).
            redis_client.zadd(window_key, {now_ms: now_ms})
            redis_client.zremrangebyscore(window_key, 0, now_ms - 60000)
            requests_in_window = redis_client.zcard(window_key)

            if requests_in_window > rate_limit.limit:
                # add ratelimit record
                db.session.add(
                    RateLimitLog(
                        tenant_id=token.tenant_id,
                        subscription_plan=rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                )
                db.session.commit()
                raise Forbidden(
                    "Sorry, you have reached the knowledge base request rate limit of your subscription."
                )

            return view(*args, **kwargs)

        return decorated

    return interceptor
def _uuid_shaped_or_none(value: object, origin: str) -> str | None:
    """Return ``str(value)`` if it has the basic shape of a UUID, else ``None``.

    The check is intentionally shallow: 36 characters containing exactly
    four hyphens. A failing ``str()`` conversion is logged (best effort) and
    treated as "no id"; ``origin`` only labels that log line.
    """
    try:
        text = str(value)
    except Exception:
        logger.exception("Failed to parse dataset_id from %s", origin)
        return None
    if len(text) == 36 and text.count("-") == 4:
        return text
    return None


@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...


@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...


def validate_dataset_token(
    view: Callable[Concatenate[T, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
    """Decorate a view so it only runs for a valid "dataset"-scoped API token.

    The wrapper validates the token, optionally validates the dataset named
    by ``dataset_id`` (taken from keyword arguments, or sniffed from the
    positional arguments), logs in the tenant owner account, and finally
    calls the view with the token's ``tenant_id`` prepended as the first
    positional argument.

    Args:
        view: The view function when used as ``@validate_dataset_token``, or
            ``None`` to obtain the decorator itself (e.g. for
            ``method_decorators``).

    Returns:
        The wrapped view, or the decorator when ``view`` is ``None``.

    Raises (from the wrapped view):
        NotFound: a dataset id was given but no such dataset belongs to the
            token's tenant.
        Forbidden: the dataset exists but has API access disabled.
        Unauthorized: no owner membership on an active tenant, or the owner
            account row is missing.
    """

    def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
        @wraps(view_func)
        def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
            api_token = validate_and_get_api_token("dataset")

            # Prefer the explicit keyword argument; fall back to positional
            # arguments only when it is absent/empty.
            dataset_id = kwargs.get("dataset_id")
            if not dataset_id and args:
                if len(args) > 1 and hasattr(args[0], "__dict__"):
                    # First arg looks like an instance (``self``), so the id,
                    # if any, is expected in the second position.
                    dataset_id = _uuid_shaped_or_none(args[1], "class method args")
                else:
                    dataset_id = _uuid_shaped_or_none(args[0], "positional args")

            # Validate dataset if dataset_id is provided
            if dataset_id:
                dataset_id = str(dataset_id)
                dataset = (
                    db.session.query(Dataset)
                    .where(
                        Dataset.id == dataset_id,
                        Dataset.tenant_id == api_token.tenant_id,
                    )
                    .first()
                )
                if not dataset:
                    raise NotFound("Dataset not found.")
                if not dataset.enable_api:
                    raise Forbidden("Dataset api access is not enabled.")

            tenant_account_join = (
                db.session.query(Tenant, TenantAccountJoin)
                .where(Tenant.id == api_token.tenant_id)
                .where(TenantAccountJoin.tenant_id == Tenant.id)
                .where(TenantAccountJoin.role.in_(["owner"]))
                .where(Tenant.status == TenantStatus.NORMAL)
                .one_or_none()
            )  # TODO: only owner information is required, so only one is returned.
            if not tenant_account_join:
                raise Unauthorized("Tenant does not exist.")

            tenant, ta = tenant_account_join
            account = db.session.query(Account).where(Account.id == ta.account_id).first()
            if not account:
                raise Unauthorized("Tenant owner account does not exist.")

            # Login admin
            account.current_tenant = tenant
            current_app.login_manager._update_request_context_with_user(account)  # type: ignore[attr-defined]
            user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore[attr-defined]

            return view_func(api_token.tenant_id, *args, **kwargs)  # type: ignore[arg-type]

        return decorated

    if view:
        return decorator(view)

    # if view is None, it means that the decorator is used without parentheses
    # use the decorator as a function for method_decorators
    return decorator
def validate_and_get_api_token(scope: str | None = None):
    """Validate the request's bearer token and return the matching API token.

    The ``Authorization`` header must have the form ``Bearer <token>``
    (scheme compared case-insensitively). The token is first looked up in
    ``ApiTokenCache``; on a hit its usage is recorded via
    ``record_token_usage`` and the cached object is returned. On a miss the
    lookup is delegated to ``fetch_token_with_single_flight``.

    Args:
        scope: Token scope forwarded to the cache and fetch helpers
            (callers in this module pass ``"app"`` or ``"dataset"``).

    Returns:
        The API token. On a cache hit this is the cached plain object
        (not a SQLAlchemy model), typed as ``ApiToken`` via ``cast``.

    Raises:
        Unauthorized: the header is missing, has no token part, or uses a
            scheme other than ``Bearer``.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header is None or " " not in auth_header:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")

    # A header such as "Bearer " contains a space but splits into a single
    # part; reject it explicitly rather than failing on tuple unpacking.
    header_parts = auth_header.split(None, 1)
    if len(header_parts) != 2:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")

    auth_scheme, auth_token = header_parts
    auth_scheme = auth_scheme.lower()

    if auth_scheme != "bearer":
        raise Unauthorized("Authorization scheme must be 'Bearer'")

    # Try to get token from cache first
    # Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
    cached_token = ApiTokenCache.get(auth_token, scope)
    if cached_token is not None:
        logger.debug("Token validation served from cache for scope: %s", scope)
        # Record usage in Redis for later batch update (no Celery task per request)
        record_token_usage(auth_token, scope)
        return cast(ApiToken, cached_token)

    # Cache miss - use Redis lock for single-flight mode
    # This ensures only one request queries DB for the same token concurrently
    return fetch_token_with_single_flight(auth_token, scope)
2023-05-15 00:51:32 +00:00
class DatasetApiResource(Resource):
    """Resource base class whose handlers are wrapped by ``validate_dataset_token``."""

    method_decorators = [validate_dataset_token]

    def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
        """Fetch a dataset by id, scoped to the given tenant.

        Raises:
            NotFound: no dataset with that id belongs to the tenant.
        """
        scoped_query = db.session.query(Dataset).where(
            Dataset.id == dataset_id,
            Dataset.tenant_id == tenant_id,
        )
        found = scoped_query.first()
        if found:
            return found
        raise NotFound("Dataset not found.")