fix: enterprise API error handling and license enforcement (#33044)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Xiyuan Chen 2026-03-15 20:59:41 -07:00 committed by GitHub
parent dd39fcd9bc
commit 977ed79ea0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 383 additions and 23 deletions

View File

@ -1,16 +1,45 @@
import logging import logging
import time import time
from flask import request
from opentelemetry.trace import get_current_span from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config from configs import dify_config
from contexts.wrapper import RecyclableContextVar from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context from core.logging.context import init_request_context
from dify_app import DifyApp from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check.
# Defined at module level to avoid per-request tuple construction.
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - features: billing/plan features (ProviderContextProvider)
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
# - workspaces/current: workspace + model providers (AppContextProvider)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/features",
"/console/api/account/profile",
"/console/api/workspaces/current",
"/console/api/version",
"/console/api/activate/check",
)
# ---------------------------- # ----------------------------
# Application Factory Function # Application Factory Function
@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
init_request_context() init_request_context()
RecyclableContextVar.increment_thread_recycles() RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for API endpoints (both console and webapp)
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
if not is_exempt:
try:
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. Please contact your administrator."
)
if license_status is None:
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise
except Exception:
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context # add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context # Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request @dify_app.after_request

View File

@ -6,6 +6,13 @@ from typing import Any
import httpx import httpx
from core.helper.trace_id_helper import generate_traceparent_header from core.helper.trace_id_helper import generate_traceparent_header
from services.errors.enterprise import (
EnterpriseAPIBadRequestError,
EnterpriseAPIError,
EnterpriseAPIForbiddenError,
EnterpriseAPINotFoundError,
EnterpriseAPIUnauthorizedError,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,10 +71,51 @@ class BaseRequest:
request_kwargs["timeout"] = timeout request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs) response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status() # Validate HTTP status and raise domain-specific errors
if not response.is_success:
cls._handle_error_response(response)
return response.json() return response.json()
@classmethod
def _handle_error_response(cls, response: httpx.Response) -> None:
"""
Handle non-2xx HTTP responses by raising appropriate domain errors.
Attempts to extract error message from JSON response body,
falls back to status text if parsing fails.
"""
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
# Try to extract error message from JSON response
try:
error_data = response.json()
if isinstance(error_data, dict):
# Common error response formats:
# {"error": "...", "message": "..."}
# {"message": "..."}
# {"detail": "..."}
error_message = (
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
)
except Exception:
# If JSON parsing fails, use the default message
logger.debug(
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
)
# Raise specific error based on status code
if response.status_code == 400:
raise EnterpriseAPIBadRequestError(error_message)
elif response.status_code == 401:
raise EnterpriseAPIUnauthorizedError(error_message)
elif response.status_code == 403:
raise EnterpriseAPIForbiddenError(error_message)
elif response.status_code == 404:
raise EnterpriseAPINotFoundError(error_message)
else:
raise EnterpriseAPIError(error_message, status_code=response.status_code)
class EnterpriseRequest(BaseRequest): class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")

View File

@ -1,15 +1,26 @@
from __future__ import annotations
import logging import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
# License status cache configuration
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppSettings(BaseModel): class WebAppSettings(BaseModel):
@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True) model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after") @model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult": def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
if self.joined and not self.workspace_id: if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True") raise ValueError("workspace_id must be non-empty when joined is True")
return self return self
@ -115,7 +126,6 @@ class EnterpriseService:
"/default-workspace/members", "/default-workspace/members",
json={"account_id": account_id}, json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
) )
if not isinstance(data, dict): if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API") raise ValueError("Invalid response format from enterprise default workspace API")
@ -223,3 +233,64 @@ class EnterpriseService:
params = {"appId": app_id} params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
balances prompt license-fix detection against DoS mitigation without
caching, every request on an expired license would hit the enterprise API.
Returns:
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
"""
if not dify_config.ENTERPRISE_ENABLED:
return None
cached = cls._read_cached_license_status()
if cached is not None:
return cached
return cls._fetch_and_cache_license_status()
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
"""Read license status from Redis cache, returning None on miss or failure."""
from services.feature_service import LicenseStatus
try:
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if raw:
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
return LicenseStatus(value)
except Exception:
logger.debug("Failed to read license status from cache", exc_info=True)
return None
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
"""Fetch license status from enterprise API and cache the result."""
from services.feature_service import LicenseStatus
try:
info = cls.get_info()
license_info = info.get("License")
if not license_info:
return None
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
ttl = (
VALID_LICENSE_CACHE_TTL
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
else INVALID_LICENSE_CACHE_TTL
)
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
except Exception:
logger.debug("Failed to cache license status", exc_info=True)
return status
except Exception:
logger.debug("Failed to fetch enterprise license status", exc_info=True)
return None

View File

@ -70,7 +70,6 @@ class PluginManagerService:
"POST", "POST",
"/pre-uninstall-plugin", "/pre-uninstall-plugin",
json=body.model_dump(), json=body.model_dump(),
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
) )
except Exception: except Exception:

View File

@ -7,6 +7,7 @@ from . import (
conversation, conversation,
dataset, dataset,
document, document,
enterprise,
file, file,
index, index,
message, message,
@ -21,6 +22,7 @@ __all__ = [
"conversation", "conversation",
"dataset", "dataset",
"document", "document",
"enterprise",
"file", "file",
"index", "index",
"message", "message",

View File

@ -0,0 +1,45 @@
"""Enterprise service errors."""
from services.errors.base import BaseServiceError
class EnterpriseServiceError(BaseServiceError):
"""Base exception for enterprise service errors."""
def __init__(self, description: str | None = None, status_code: int | None = None):
super().__init__(description)
self.status_code = status_code
class EnterpriseAPIError(EnterpriseServiceError):
"""Generic enterprise API error (non-2xx response)."""
pass
class EnterpriseAPINotFoundError(EnterpriseServiceError):
"""Enterprise API returned 404 Not Found."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=404)
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
"""Enterprise API returned 403 Forbidden."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=403)
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
"""Enterprise API returned 401 Unauthorized."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=401)
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
"""Enterprise API returned 400 Bad Request."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=400)

View File

@ -379,14 +379,19 @@ class FeatureService:
) )
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if is_authenticated and (license_info := enterprise_info.get("License")): # SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
# so the login page can detect an expired/inactive license after force-logout.
# All other license details (expiry date, workspace usage) remain auth-gated.
# This behavior reflects prior internal review of information-leakage risks.
if license_info := enterprise_info.get("License"):
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"): if is_authenticated:
features.license.workspaces.enabled = workspaces_info.get("enabled", False) features.license.expired_at = license_info.get("expiredAt", "")
features.license.workspaces.limit = workspaces_info.get("limit", 0) if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.size = workspaces_info.get("used", 0) features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info: if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"] plugin_installation_info = enterprise_info["PluginInstallationPermission"]

View File

@ -358,10 +358,9 @@ class TestFeatureService:
assert result is not None assert result is not None
assert isinstance(result, SystemFeatureModel) assert isinstance(result, SystemFeatureModel)
# --- 1. Verify Response Payload Optimization (Data Minimization) --- # --- 1. Verify only license *status* is exposed to unauthenticated clients ---
# Ensure only essential UI flags are returned to unauthenticated clients # Detailed license info (expiry, workspaces) remains auth-gated.
# to keep the payload lightweight and adhere to architectural boundaries. assert result.license.status == LicenseStatus.ACTIVE
assert result.license.status == LicenseStatus.NONE
assert result.license.expired_at == "" assert result.license.expired_at == ""
assert result.license.workspaces.enabled is False assert result.license.workspaces.enabled is False
assert result.license.workspaces.limit == 0 assert result.license.workspaces.limit == 0

View File

@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations. """Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior: Covers:
- Enterprise mode disabled: no external calls - Default workspace auto-join behavior
- Successful join / skipped join: no errors - License status caching (get_cached_license_status)
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
""" """
from unittest.mock import patch from unittest.mock import patch
@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest import pytest
from services.enterprise.enterprise_service import ( from services.enterprise.enterprise_service import (
INVALID_LICENSE_CACHE_TTL,
LICENSE_STATUS_CACHE_KEY,
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult, DefaultWorkspaceJoinResult,
EnterpriseService, EnterpriseService,
try_join_default_workspace, try_join_default_workspace,
@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
"/default-workspace/members", "/default-workspace/members",
json={"account_id": account_id}, json={"account_id": account_id},
timeout=1.0, timeout=1.0,
raise_for_status=True,
) )
def test_join_default_workspace_invalid_response_format_raises(self): def test_join_default_workspace_invalid_response_format_raises(self):
@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
# Should not raise even though UUID parsing fails inside join_default_workspace # Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid") try_join_default_workspace("not-a-uuid")
# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------
_EE_SVC = "services.enterprise.enterprise_service"
class TestGetCachedLicenseStatus:
"""Tests for EnterpriseService.get_cached_license_status."""
def test_returns_none_when_enterprise_disabled(self):
with patch(f"{_EE_SVC}.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
assert EnterpriseService.get_cached_license_status() is None
def test_cache_hit_returns_license_status_enum(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = b"active"
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
assert isinstance(result, LicenseStatus)
mock_get_info.assert_not_called()
def test_cache_miss_fetches_api_and_caches_valid_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
)
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "expired"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRED
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
)
def test_redis_read_failure_falls_through_to_api(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_get_info.assert_called_once()
def test_redis_write_failure_still_returns_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_redis.setex.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "expiring"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRING
def test_api_failure_returns_none(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.side_effect = Exception("network failure")
assert EnterpriseService.get_cached_license_status() is None
def test_api_returns_no_license_info(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {} # no "License" key
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()

View File

@ -34,7 +34,6 @@ class TestTryPreUninstallPlugin:
"POST", "POST",
"/pre-uninstall-plugin", "/pre-uninstall-plugin",
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"}, json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
) )
@ -62,7 +61,6 @@ class TestTryPreUninstallPlugin:
"POST", "POST",
"/pre-uninstall-plugin", "/pre-uninstall-plugin",
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"}, json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
) )
mock_logger.exception.assert_called_once() mock_logger.exception.assert_called_once()
@ -87,7 +85,6 @@ class TestTryPreUninstallPlugin:
"POST", "POST",
"/pre-uninstall-plugin", "/pre-uninstall-plugin",
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"}, json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
) )
mock_logger.exception.assert_called_once() mock_logger.exception.assert_called_once()