mirror of https://github.com/langgenius/dify.git
fix: enterprise API error handling and license enforcement (#33044)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
dd39fcd9bc
commit
977ed79ea0
|
|
@ -1,16 +1,45 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from flask import request
|
||||||
from opentelemetry.trace import get_current_span
|
from opentelemetry.trace import get_current_span
|
||||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from contexts.wrapper import RecyclableContextVar
|
from contexts.wrapper import RecyclableContextVar
|
||||||
|
from controllers.console.error import UnauthorizedAndForceLogout
|
||||||
from core.logging.context import init_request_context
|
from core.logging.context import init_request_context
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Console bootstrap APIs exempt from license check.
|
||||||
|
# Defined at module level to avoid per-request tuple construction.
|
||||||
|
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||||
|
# - setup: install/setup status check (AppInitializer)
|
||||||
|
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||||
|
# - login: auto-login after setup completion (InstallForm)
|
||||||
|
# - features: billing/plan features (ProviderContextProvider)
|
||||||
|
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||||
|
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||||
|
# - version: version check (AppContextProvider)
|
||||||
|
# - activate/check: invitation link validation (signin page)
|
||||||
|
# Without these exemptions, the signin page triggers location.reload()
|
||||||
|
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||||
|
_CONSOLE_EXEMPT_PREFIXES = (
|
||||||
|
"/console/api/system-features",
|
||||||
|
"/console/api/setup",
|
||||||
|
"/console/api/init",
|
||||||
|
"/console/api/login",
|
||||||
|
"/console/api/features",
|
||||||
|
"/console/api/account/profile",
|
||||||
|
"/console/api/workspaces/current",
|
||||||
|
"/console/api/version",
|
||||||
|
"/console/api/activate/check",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# Application Factory Function
|
# Application Factory Function
|
||||||
|
|
@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||||
init_request_context()
|
init_request_context()
|
||||||
RecyclableContextVar.increment_thread_recycles()
|
RecyclableContextVar.increment_thread_recycles()
|
||||||
|
|
||||||
|
# Enterprise license validation for API endpoints (both console and webapp)
|
||||||
|
# When license expires, block all API access except bootstrap endpoints needed
|
||||||
|
# for the frontend to load the license expiration page without infinite reloads.
|
||||||
|
if dify_config.ENTERPRISE_ENABLED:
|
||||||
|
is_console_api = request.path.startswith("/console/api/")
|
||||||
|
is_webapp_api = request.path.startswith("/api/")
|
||||||
|
|
||||||
|
if is_console_api or is_webapp_api:
|
||||||
|
if is_console_api:
|
||||||
|
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||||
|
else: # webapp API
|
||||||
|
is_exempt = request.path.startswith("/api/system-features")
|
||||||
|
|
||||||
|
if not is_exempt:
|
||||||
|
try:
|
||||||
|
# Check license status (cached — see EnterpriseService for TTL details)
|
||||||
|
license_status = EnterpriseService.get_cached_license_status()
|
||||||
|
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||||
|
raise UnauthorizedAndForceLogout(
|
||||||
|
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||||
|
)
|
||||||
|
if license_status is None:
|
||||||
|
raise UnauthorizedAndForceLogout(
|
||||||
|
"Unable to verify enterprise license. Please contact your administrator."
|
||||||
|
)
|
||||||
|
except UnauthorizedAndForceLogout:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to check enterprise license status")
|
||||||
|
raise UnauthorizedAndForceLogout(
|
||||||
|
"Unable to verify enterprise license. Please contact your administrator."
|
||||||
|
)
|
||||||
|
|
||||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||||
# Only adds headers when OTEL is enabled and has valid context
|
# Only adds headers when OTEL is enabled and has valid context
|
||||||
@dify_app.after_request
|
@dify_app.after_request
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,13 @@ from typing import Any
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from core.helper.trace_id_helper import generate_traceparent_header
|
from core.helper.trace_id_helper import generate_traceparent_header
|
||||||
|
from services.errors.enterprise import (
|
||||||
|
EnterpriseAPIBadRequestError,
|
||||||
|
EnterpriseAPIError,
|
||||||
|
EnterpriseAPIForbiddenError,
|
||||||
|
EnterpriseAPINotFoundError,
|
||||||
|
EnterpriseAPIUnauthorizedError,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -64,10 +71,51 @@ class BaseRequest:
|
||||||
request_kwargs["timeout"] = timeout
|
request_kwargs["timeout"] = timeout
|
||||||
|
|
||||||
response = client.request(method, url, **request_kwargs)
|
response = client.request(method, url, **request_kwargs)
|
||||||
if raise_for_status:
|
|
||||||
response.raise_for_status()
|
# Validate HTTP status and raise domain-specific errors
|
||||||
|
if not response.is_success:
|
||||||
|
cls._handle_error_response(response)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _handle_error_response(cls, response: httpx.Response) -> None:
|
||||||
|
"""
|
||||||
|
Handle non-2xx HTTP responses by raising appropriate domain errors.
|
||||||
|
|
||||||
|
Attempts to extract error message from JSON response body,
|
||||||
|
falls back to status text if parsing fails.
|
||||||
|
"""
|
||||||
|
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
|
||||||
|
|
||||||
|
# Try to extract error message from JSON response
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
if isinstance(error_data, dict):
|
||||||
|
# Common error response formats:
|
||||||
|
# {"error": "...", "message": "..."}
|
||||||
|
# {"message": "..."}
|
||||||
|
# {"detail": "..."}
|
||||||
|
error_message = (
|
||||||
|
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# If JSON parsing fails, use the default message
|
||||||
|
logger.debug(
|
||||||
|
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Raise specific error based on status code
|
||||||
|
if response.status_code == 400:
|
||||||
|
raise EnterpriseAPIBadRequestError(error_message)
|
||||||
|
elif response.status_code == 401:
|
||||||
|
raise EnterpriseAPIUnauthorizedError(error_message)
|
||||||
|
elif response.status_code == 403:
|
||||||
|
raise EnterpriseAPIForbiddenError(error_message)
|
||||||
|
elif response.status_code == 404:
|
||||||
|
raise EnterpriseAPINotFoundError(error_message)
|
||||||
|
else:
|
||||||
|
raise EnterpriseAPIError(error_message, status_code=response.status_code)
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseRequest(BaseRequest):
|
class EnterpriseRequest(BaseRequest):
|
||||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,26 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
from services.enterprise.base import EnterpriseRequest
|
from services.enterprise.base import EnterpriseRequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||||
|
# License status cache configuration
|
||||||
|
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
|
||||||
|
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
||||||
|
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||||
|
|
||||||
|
|
||||||
class WebAppSettings(BaseModel):
|
class WebAppSettings(BaseModel):
|
||||||
|
|
@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
|
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
|
||||||
if self.joined and not self.workspace_id:
|
if self.joined and not self.workspace_id:
|
||||||
raise ValueError("workspace_id must be non-empty when joined is True")
|
raise ValueError("workspace_id must be non-empty when joined is True")
|
||||||
return self
|
return self
|
||||||
|
|
@ -115,7 +126,6 @@ class EnterpriseService:
|
||||||
"/default-workspace/members",
|
"/default-workspace/members",
|
||||||
json={"account_id": account_id},
|
json={"account_id": account_id},
|
||||||
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
||||||
raise_for_status=True,
|
|
||||||
)
|
)
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise ValueError("Invalid response format from enterprise default workspace API")
|
raise ValueError("Invalid response format from enterprise default workspace API")
|
||||||
|
|
@ -223,3 +233,64 @@ class EnterpriseService:
|
||||||
|
|
||||||
params = {"appId": app_id}
|
params = {"appId": app_id}
|
||||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||||
|
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||||
|
|
||||||
|
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
|
||||||
|
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
|
||||||
|
balances prompt license-fix detection against DoS mitigation — without
|
||||||
|
caching, every request on an expired license would hit the enterprise API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
|
||||||
|
"""
|
||||||
|
if not dify_config.ENTERPRISE_ENABLED:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached = cls._read_cached_license_status()
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
return cls._fetch_and_cache_license_status()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _read_cached_license_status(cls) -> LicenseStatus | None:
|
||||||
|
"""Read license status from Redis cache, returning None on miss or failure."""
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
|
||||||
|
if raw:
|
||||||
|
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
||||||
|
return LicenseStatus(value)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to read license status from cache", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
|
||||||
|
"""Fetch license status from enterprise API and cache the result."""
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = cls.get_info()
|
||||||
|
license_info = info.get("License")
|
||||||
|
if not license_info:
|
||||||
|
return None
|
||||||
|
|
||||||
|
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||||
|
ttl = (
|
||||||
|
VALID_LICENSE_CACHE_TTL
|
||||||
|
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
|
||||||
|
else INVALID_LICENSE_CACHE_TTL
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to cache license status", exc_info=True)
|
||||||
|
return status
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to fetch enterprise license status", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,6 @@ class PluginManagerService:
|
||||||
"POST",
|
"POST",
|
||||||
"/pre-uninstall-plugin",
|
"/pre-uninstall-plugin",
|
||||||
json=body.model_dump(),
|
json=body.model_dump(),
|
||||||
raise_for_status=True,
|
|
||||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from . import (
|
||||||
conversation,
|
conversation,
|
||||||
dataset,
|
dataset,
|
||||||
document,
|
document,
|
||||||
|
enterprise,
|
||||||
file,
|
file,
|
||||||
index,
|
index,
|
||||||
message,
|
message,
|
||||||
|
|
@ -21,6 +22,7 @@ __all__ = [
|
||||||
"conversation",
|
"conversation",
|
||||||
"dataset",
|
"dataset",
|
||||||
"document",
|
"document",
|
||||||
|
"enterprise",
|
||||||
"file",
|
"file",
|
||||||
"index",
|
"index",
|
||||||
"message",
|
"message",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""Enterprise service errors."""
|
||||||
|
|
||||||
|
from services.errors.base import BaseServiceError
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseServiceError(BaseServiceError):
|
||||||
|
"""Base exception for enterprise service errors."""
|
||||||
|
|
||||||
|
def __init__(self, description: str | None = None, status_code: int | None = None):
|
||||||
|
super().__init__(description)
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseAPIError(EnterpriseServiceError):
|
||||||
|
"""Generic enterprise API error (non-2xx response)."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseAPINotFoundError(EnterpriseServiceError):
|
||||||
|
"""Enterprise API returned 404 Not Found."""
|
||||||
|
|
||||||
|
def __init__(self, description: str | None = None):
|
||||||
|
super().__init__(description, status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
|
||||||
|
"""Enterprise API returned 403 Forbidden."""
|
||||||
|
|
||||||
|
def __init__(self, description: str | None = None):
|
||||||
|
super().__init__(description, status_code=403)
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
|
||||||
|
"""Enterprise API returned 401 Unauthorized."""
|
||||||
|
|
||||||
|
def __init__(self, description: str | None = None):
|
||||||
|
super().__init__(description, status_code=401)
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
|
||||||
|
"""Enterprise API returned 400 Bad Request."""
|
||||||
|
|
||||||
|
def __init__(self, description: str | None = None):
|
||||||
|
super().__init__(description, status_code=400)
|
||||||
|
|
@ -379,14 +379,19 @@ class FeatureService:
|
||||||
)
|
)
|
||||||
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
|
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
|
||||||
|
|
||||||
if is_authenticated and (license_info := enterprise_info.get("License")):
|
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
|
||||||
|
# so the login page can detect an expired/inactive license after force-logout.
|
||||||
|
# All other license details (expiry date, workspace usage) remain auth-gated.
|
||||||
|
# This behavior reflects prior internal review of information-leakage risks.
|
||||||
|
if license_info := enterprise_info.get("License"):
|
||||||
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||||
features.license.expired_at = license_info.get("expiredAt", "")
|
|
||||||
|
|
||||||
if workspaces_info := license_info.get("workspaces"):
|
if is_authenticated:
|
||||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
features.license.expired_at = license_info.get("expiredAt", "")
|
||||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
if workspaces_info := license_info.get("workspaces"):
|
||||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||||
|
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||||
|
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||||
|
|
||||||
if "PluginInstallationPermission" in enterprise_info:
|
if "PluginInstallationPermission" in enterprise_info:
|
||||||
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
|
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
|
||||||
|
|
|
||||||
|
|
@ -358,10 +358,9 @@ class TestFeatureService:
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, SystemFeatureModel)
|
assert isinstance(result, SystemFeatureModel)
|
||||||
|
|
||||||
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
|
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
|
||||||
# Ensure only essential UI flags are returned to unauthenticated clients
|
# Detailed license info (expiry, workspaces) remains auth-gated.
|
||||||
# to keep the payload lightweight and adhere to architectural boundaries.
|
assert result.license.status == LicenseStatus.ACTIVE
|
||||||
assert result.license.status == LicenseStatus.NONE
|
|
||||||
assert result.license.expired_at == ""
|
assert result.license.expired_at == ""
|
||||||
assert result.license.workspaces.enabled is False
|
assert result.license.workspaces.enabled is False
|
||||||
assert result.license.workspaces.limit == 0
|
assert result.license.workspaces.limit == 0
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
"""Unit tests for enterprise service integrations.
|
"""Unit tests for enterprise service integrations.
|
||||||
|
|
||||||
This module covers the enterprise-only default workspace auto-join behavior:
|
Covers:
|
||||||
- Enterprise mode disabled: no external calls
|
- Default workspace auto-join behavior
|
||||||
- Successful join / skipped join: no errors
|
- License status caching (get_cached_license_status)
|
||||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
@ -11,6 +10,9 @@ from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from services.enterprise.enterprise_service import (
|
from services.enterprise.enterprise_service import (
|
||||||
|
INVALID_LICENSE_CACHE_TTL,
|
||||||
|
LICENSE_STATUS_CACHE_KEY,
|
||||||
|
VALID_LICENSE_CACHE_TTL,
|
||||||
DefaultWorkspaceJoinResult,
|
DefaultWorkspaceJoinResult,
|
||||||
EnterpriseService,
|
EnterpriseService,
|
||||||
try_join_default_workspace,
|
try_join_default_workspace,
|
||||||
|
|
@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
|
||||||
"/default-workspace/members",
|
"/default-workspace/members",
|
||||||
json={"account_id": account_id},
|
json={"account_id": account_id},
|
||||||
timeout=1.0,
|
timeout=1.0,
|
||||||
raise_for_status=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_join_default_workspace_invalid_response_format_raises(self):
|
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||||
|
|
@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
|
||||||
|
|
||||||
# Should not raise even though UUID parsing fails inside join_default_workspace
|
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||||
try_join_default_workspace("not-a-uuid")
|
try_join_default_workspace("not-a-uuid")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_cached_license_status
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_EE_SVC = "services.enterprise.enterprise_service"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCachedLicenseStatus:
|
||||||
|
"""Tests for EnterpriseService.get_cached_license_status."""
|
||||||
|
|
||||||
|
def test_returns_none_when_enterprise_disabled(self):
|
||||||
|
with patch(f"{_EE_SVC}.dify_config") as mock_config:
|
||||||
|
mock_config.ENTERPRISE_ENABLED = False
|
||||||
|
|
||||||
|
assert EnterpriseService.get_cached_license_status() is None
|
||||||
|
|
||||||
|
def test_cache_hit_returns_license_status_enum(self):
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.return_value = b"active"
|
||||||
|
|
||||||
|
result = EnterpriseService.get_cached_license_status()
|
||||||
|
|
||||||
|
assert result == LicenseStatus.ACTIVE
|
||||||
|
assert isinstance(result, LicenseStatus)
|
||||||
|
mock_get_info.assert_not_called()
|
||||||
|
|
||||||
|
def test_cache_miss_fetches_api_and_caches_valid_status(self):
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_get_info.return_value = {"License": {"status": "active"}}
|
||||||
|
|
||||||
|
result = EnterpriseService.get_cached_license_status()
|
||||||
|
|
||||||
|
assert result == LicenseStatus.ACTIVE
|
||||||
|
mock_redis.setex.assert_called_once_with(
|
||||||
|
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_get_info.return_value = {"License": {"status": "expired"}}
|
||||||
|
|
||||||
|
result = EnterpriseService.get_cached_license_status()
|
||||||
|
|
||||||
|
assert result == LicenseStatus.EXPIRED
|
||||||
|
mock_redis.setex.assert_called_once_with(
|
||||||
|
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_redis_read_failure_falls_through_to_api(self):
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.side_effect = ConnectionError("redis down")
|
||||||
|
mock_get_info.return_value = {"License": {"status": "active"}}
|
||||||
|
|
||||||
|
result = EnterpriseService.get_cached_license_status()
|
||||||
|
|
||||||
|
assert result == LicenseStatus.ACTIVE
|
||||||
|
mock_get_info.assert_called_once()
|
||||||
|
|
||||||
|
def test_redis_write_failure_still_returns_status(self):
|
||||||
|
from services.feature_service import LicenseStatus
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_redis.setex.side_effect = ConnectionError("redis down")
|
||||||
|
mock_get_info.return_value = {"License": {"status": "expiring"}}
|
||||||
|
|
||||||
|
result = EnterpriseService.get_cached_license_status()
|
||||||
|
|
||||||
|
assert result == LicenseStatus.EXPIRING
|
||||||
|
|
||||||
|
def test_api_failure_returns_none(self):
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_get_info.side_effect = Exception("network failure")
|
||||||
|
|
||||||
|
assert EnterpriseService.get_cached_license_status() is None
|
||||||
|
|
||||||
|
def test_api_returns_no_license_info(self):
|
||||||
|
with (
|
||||||
|
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||||
|
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||||
|
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
mock_get_info.return_value = {} # no "License" key
|
||||||
|
|
||||||
|
assert EnterpriseService.get_cached_license_status() is None
|
||||||
|
mock_redis.setex.assert_not_called()
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ class TestTryPreUninstallPlugin:
|
||||||
"POST",
|
"POST",
|
||||||
"/pre-uninstall-plugin",
|
"/pre-uninstall-plugin",
|
||||||
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
|
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
|
||||||
raise_for_status=True,
|
|
||||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -62,7 +61,6 @@ class TestTryPreUninstallPlugin:
|
||||||
"POST",
|
"POST",
|
||||||
"/pre-uninstall-plugin",
|
"/pre-uninstall-plugin",
|
||||||
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
|
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
|
||||||
raise_for_status=True,
|
|
||||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
mock_logger.exception.assert_called_once()
|
mock_logger.exception.assert_called_once()
|
||||||
|
|
@ -87,7 +85,6 @@ class TestTryPreUninstallPlugin:
|
||||||
"POST",
|
"POST",
|
||||||
"/pre-uninstall-plugin",
|
"/pre-uninstall-plugin",
|
||||||
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
|
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
|
||||||
raise_for_status=True,
|
|
||||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
mock_logger.exception.assert_called_once()
|
mock_logger.exception.assert_called_once()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue