This commit is contained in:
wangxiaolei 2026-03-24 06:23:22 +00:00 committed by GitHub
commit 0418e7c371
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 479 additions and 11 deletions

View File

@ -29,6 +29,14 @@ class SecurityConfig(BaseSettings):
default="",
)
SESSION_REVOCATION_STORAGE: str = Field(
description=(
"Session revocation storage backend for JWT jti revocation checks. "
"Options: 'null' (default, disabled) or 'redis' (use configured Redis)."
),
default="null",
)
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in minutes for which a password reset token remains valid",
default=5,

View File

@ -37,6 +37,7 @@ from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
@ -157,7 +158,9 @@ class LogoutApi(Resource):
if isinstance(account, flask_login.AnonymousUserMixin):
response = make_response({"result": "success"})
else:
AccountService.logout(account=account)
# Extract access_token from request to revoke it
access_token = extract_access_token(request)
AccountService.logout(account=account, access_token=access_token)
flask_login.logout_user()
response = make_response({"result": "success"})

View File

@ -1,19 +1,65 @@
import time
import uuid
from datetime import UTC, datetime
import jwt
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from libs.session_revocation_storage import get_session_revocation_storage
def _get_blacklist_key(jti: str) -> str:
return f"passport:blacklist:jti:{jti}"
class PassportService:
def __init__(self):
self.sk = dify_config.SECRET_KEY
@classmethod
def _get_blacklist_key(cls, jti: str) -> str:
"""Instance-accessible helper for tests and internal use."""
return _get_blacklist_key(jti)
def issue(self, payload):
return jwt.encode(payload, self.sk, algorithm="HS256")
# Add jti (JWT ID) if not present for token revocation support
payload_to_encode = dict(payload)
if "jti" not in payload_to_encode:
payload_to_encode["jti"] = str(uuid.uuid4())
return jwt.encode(payload_to_encode, self.sk, algorithm="HS256")
@classmethod
def revoke(cls, token: str) -> bool:
try:
payload = jwt.decode(token, options={"verify_signature": False})
except jwt.PyJWTError:
# Invalid/garbled token: treat as non-revocable
return False
token_id = payload.get("jti") or token
exp = payload.get("exp")
if not exp:
return False
ttl = int(exp - time.time())
if ttl <= 0:
return False
storage = get_session_revocation_storage()
storage.revoke(token_id, datetime.fromtimestamp(exp, tz=UTC))
return True
def verify(self, token):
"""Verify a JWT and then enforce revocation via SessionRevocationStorage.
The signature and standard claims are verified first to avoid any processing
of untrusted data for invalid tokens. After successful verification we consult
the configured revocation storage using the token's `jti` when present.
"""
# 1) Verify signature/claims first
try:
return jwt.decode(token, self.sk, algorithms=["HS256"])
verified_payload = jwt.decode(token, self.sk, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
raise Unauthorized("Token has expired.")
except jwt.InvalidSignatureError:
@ -22,3 +68,11 @@ class PassportService:
raise Unauthorized("Invalid token.")
except jwt.PyJWTError: # Catch-all for other JWT errors
raise Unauthorized("Invalid token.")
# 2) Enforce revocation using storage (supports old tokens without jti)
storage = get_session_revocation_storage()
token_id = verified_payload.get("jti") or token
if storage.is_revoked(token_id):
raise Unauthorized("Token has been revoked.")
return verified_payload

View File

@ -0,0 +1,73 @@
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Protocol, runtime_checkable
from configs import dify_config
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@runtime_checkable
class SessionRevocationStorage(Protocol):
def revoke(self, token_id: str, expiration_time: datetime) -> None: ...
def is_revoked(self, token_id: str) -> bool: ...
def expunge(self) -> None: ...
class NullSessionRevocationStorage(SessionRevocationStorage):
def revoke(self, token_id: str, expiration_time: datetime) -> None:
return None
def is_revoked(self, token_id: str) -> bool:
return False
def expunge(self) -> None:
return None
class RedisSessionRevocationStorage(SessionRevocationStorage):
def __init__(self, key_prefix: str = "passport:blacklist:jti:") -> None:
self.key_prefix = key_prefix
def _key(self, token_id: str) -> str:
return f"{self.key_prefix}{token_id}"
def revoke(self, token_id: str, expiration_time: datetime) -> None:
# Compute remaining lifetime in seconds
now_ts = datetime.now(UTC).timestamp()
ttl = int(expiration_time.timestamp() - now_ts)
if ttl <= 0:
return None
redis_client.setex(self._key(token_id), ttl, b"1")
return None
def is_revoked(self, token_id: str) -> bool:
return bool(redis_client.exists(self._key(token_id)))
def expunge(self) -> None:
return None
_singleton: SessionRevocationStorage | None = None
def get_session_revocation_storage() -> SessionRevocationStorage:
global _singleton
if _singleton is not None:
return _singleton
backend = (dify_config.SESSION_REVOCATION_STORAGE or "null").strip().lower()
if backend in ("", "null", "disabled", "off"):
_singleton = NullSessionRevocationStorage()
elif backend == "redis":
_singleton = RedisSessionRevocationStorage()
else:
logger.warning(
"Unknown SESSION_REVOCATION_STORAGE '%s'; falling back to 'null' (disabled).",
backend,
)
_singleton = NullSessionRevocationStorage()
return _singleton

View File

@ -440,7 +440,12 @@ class AccountService:
return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token)
@staticmethod
def logout(*, account: Account):
def logout(*, account: Account, access_token: str | None = None):
# Revoke access_token if provided
if access_token:
PassportService.revoke(access_token)
# Delete refresh_token
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)

View File

@ -462,7 +462,7 @@ class TestLogoutApi:
response = logout_api.post()
# Assert
mock_service_logout.assert_called_once_with(account=mock_account)
mock_service_logout.assert_called_once_with(account=mock_account, access_token=None)
mock_logout_user.assert_called_once()
assert response.json["result"] == "success"
@ -493,3 +493,135 @@ class TestLogoutApi:
# Assert
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
@patch("controllers.console.auth.login.extract_access_token")
def test_logout_with_access_token_in_cookie(
self,
mock_extract_access_token,
mock_logout_user,
mock_service_logout,
mock_current_account,
mock_db,
app,
mock_account,
):
"""
Test logout with access token in cookie.
Verifies that:
- Access token is extracted from the request
- Token is passed to AccountService.logout for revocation
- User session is terminated
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
test_access_token = "test.jwt.token"
# Mock extract_access_token to return a token
mock_extract_access_token.return_value = test_access_token
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = logout_api.post()
# Assert
# Verify extract_access_token was called with request
mock_extract_access_token.assert_called_once()
# Verify logout was called with the access token
mock_service_logout.assert_called_once_with(account=mock_account, access_token=test_access_token)
mock_logout_user.assert_called_once()
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
@patch("controllers.console.auth.login.extract_access_token")
def test_logout_with_access_token_in_authorization_header(
self,
mock_extract_access_token,
mock_logout_user,
mock_service_logout,
mock_current_account,
mock_db,
app,
mock_account,
):
"""
Test logout with access token in Authorization header.
Verifies that:
- Access token is extracted from Authorization header
- Token is passed to AccountService.logout for revocation
- User session is terminated
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
test_access_token = "bearer.token.from.header"
# Mock extract_access_token to return a token from header
mock_extract_access_token.return_value = test_access_token
# Act
with app.test_request_context(
"/logout", method="POST", headers={"Authorization": f"Bearer {test_access_token}"}
):
logout_api = LogoutApi()
response = logout_api.post()
# Assert
# Verify logout was called with the access token from header
mock_service_logout.assert_called_once_with(account=mock_account, access_token=test_access_token)
mock_logout_user.assert_called_once()
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
@patch("controllers.console.auth.login.extract_access_token")
def test_logout_with_no_access_token(
self,
mock_extract_access_token,
mock_logout_user,
mock_service_logout,
mock_current_account,
mock_db,
app,
mock_account,
):
"""
Test logout when no access token is present.
Verifies that:
- Logout proceeds without error when no token is present
- AccountService.logout is called with access_token=None
- User session is still terminated
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Mock extract_access_token to return None (no token in request)
mock_extract_access_token.return_value = None
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = logout_api.post()
# Assert
# Verify logout was called with access_token=None
mock_service_logout.assert_called_once_with(account=mock_account, access_token=None)
mock_logout_user.assert_called_once()
assert response.json["result"] == "success"

View File

@ -1,5 +1,5 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import jwt
import pytest
@ -35,9 +35,12 @@ class TestPassportService:
assert isinstance(token, str)
assert len(token.split(".")) == 3 # JWT format: header.payload.signature
# Verify token content
# Verify token content (jti is automatically added)
decoded = passport_service.verify(token)
assert decoded == payload
assert decoded["user_id"] == payload["user_id"]
assert decoded["app_code"] == payload["app_code"]
assert "jti" in decoded # jti is automatically added
assert isinstance(decoded["jti"], str)
def test_should_handle_different_payload_types(self, passport_service):
"""Test issuing and verifying tokens with different payload types"""
@ -57,7 +60,11 @@ class TestPassportService:
for payload in test_cases:
token = passport_service.issue(payload)
decoded = passport_service.verify(token)
assert decoded == payload
# Verify all original fields are present
for key, value in payload.items():
assert decoded[key] == value
# Verify jti is added
assert "jti" in decoded
# Security tests
def test_should_reject_modified_token(self, passport_service):
@ -153,7 +160,10 @@ class TestPassportService:
payload = {"test": "data"}
token = service.issue(payload)
decoded = service.verify(token)
assert decoded == payload
# Verify original payload fields are present
assert decoded["test"] == payload["test"]
# jti is automatically added
assert "jti" in decoded
def test_should_handle_none_secret_key(self):
"""Test behavior when SECRET_KEY is None"""
@ -192,7 +202,11 @@ class TestPassportService:
for payload in special_payloads:
token = passport_service.issue(payload)
decoded = passport_service.verify(token)
assert decoded == payload
# Verify all original fields are present
for key, value in payload.items():
assert decoded[key] == value
# jti is automatically added
assert "jti" in decoded
def test_should_catch_generic_pyjwt_errors(self, passport_service):
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
@ -203,3 +217,133 @@ class TestPassportService:
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify("some-token")
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
# Token blacklist tests
def test_should_revoke_token_successfully(self, passport_service):
"""Test that tokens can be revoked and added to blacklist using jti"""
payload = {"user_id": "123", "exp": (datetime.now(UTC) + timedelta(hours=1)).timestamp()}
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = "test-secret-key-for-testing"
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
# Mock redis_client (via storage layer) and force storage backend to redis
mock_redis = MagicMock()
import libs.session_revocation_storage as srs
srs._singleton = None
with (
patch("libs.session_revocation_storage.dify_config") as s_cfg,
patch("libs.session_revocation_storage.redis_client", mock_redis),
):
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
result = passport_service.revoke(token)
assert result is True
mock_redis.setex.assert_called_once()
# Verify the key format uses jti
call_args = mock_redis.setex.call_args
key = call_args[0][0]
assert "passport:blacklist:jti:" in key
# Verify TTL is approximately 1 hour (3600 seconds)
ttl = call_args[0][1]
assert 3500 < ttl <= 3600 # Allow some tolerance
def test_should_not_revoke_already_expired_token(self, passport_service):
"""Test that already expired tokens are not added to blacklist"""
past_time = datetime.now(UTC) - timedelta(hours=1)
payload = {"user_id": "123", "exp": past_time.timestamp()}
with patch("libs.passport.dify_config") as mock_config:
mock_config.SECRET_KEY = "test-secret-key-for-testing"
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
mock_redis = MagicMock()
import libs.session_revocation_storage as srs
srs._singleton = None
with (
patch("libs.session_revocation_storage.dify_config") as s_cfg,
patch("libs.session_revocation_storage.redis_client", mock_redis),
):
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
result = passport_service.revoke(token)
assert result is False
mock_redis.setex.assert_not_called()
def test_should_reject_revoked_token(self, passport_service):
"""Test that revoked tokens cannot be verified"""
payload = {"user_id": "123"}
token = passport_service.issue(payload)
# Get the jti from the token
decoded = jwt.decode(token, options={"verify_signature": False})
jti = decoded.get("jti")
# Mock redis to simulate token being revoked (jti in blacklist)
mock_redis = MagicMock()
mock_redis.exists.return_value = True
with patch("libs.session_revocation_storage.redis_client", mock_redis):
import libs.session_revocation_storage as srs
srs._singleton = None
with patch("libs.session_revocation_storage.dify_config") as s_cfg:
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify(token)
assert "revoked" in str(exc_info.value).lower()
# Verify that jti was used to check blacklist
mock_redis.exists.assert_called_once_with(f"passport:blacklist:jti:{jti}")
def test_should_verify_non_revoked_token(self, passport_service):
"""Test that non-revoked tokens can be verified"""
payload = {"user_id": "123"}
token = passport_service.issue(payload)
# Get the jti from the token
decoded = jwt.decode(token, options={"verify_signature": False})
jti = decoded.get("jti")
# Mock redis to return False (token not in blacklist)
mock_redis = MagicMock()
mock_redis.exists.return_value = False
with patch("libs.session_revocation_storage.redis_client", mock_redis):
import libs.session_revocation_storage as srs
srs._singleton = None
with patch("libs.session_revocation_storage.dify_config") as s_cfg:
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
decoded = passport_service.verify(token)
assert decoded["user_id"] == payload["user_id"]
assert "jti" in decoded
mock_redis.exists.assert_called_once_with(f"passport:blacklist:jti:{jti}")
def test_should_handle_revoke_with_invalid_token(self, passport_service):
"""Test that revoke handles invalid tokens gracefully"""
invalid_token = "invalid.token.here"
mock_redis = MagicMock()
import libs.session_revocation_storage as srs
srs._singleton = None
with (
patch("libs.session_revocation_storage.dify_config") as s_cfg,
patch("libs.session_revocation_storage.redis_client", mock_redis),
):
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
result = passport_service.revoke(invalid_token)
assert result is False
mock_redis.setex.assert_not_called()
def test_should_generate_correct_blacklist_key(self, passport_service):
"""Test that blacklist key is generated correctly using jti"""
jti = "test-jti-uuid"
expected_key = f"passport:blacklist:jti:{jti}"
actual_key = passport_service._get_blacklist_key(jti)
assert actual_key == expected_key

View File

@ -0,0 +1,49 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
from libs.session_revocation_storage import (
NullSessionRevocationStorage,
RedisSessionRevocationStorage,
get_session_revocation_storage,
)
class TestSessionRevocationStorage:
def test_null_storage_behaviour(self):
s = NullSessionRevocationStorage()
s.revoke("any", datetime.now(UTC) + timedelta(hours=1))
assert s.is_revoked("any") is False
s.expunge()
def test_redis_storage_revoke_sets_ttl_and_is_revoked_true(self):
mock_redis = MagicMock()
storage = RedisSessionRevocationStorage()
token_id = "jti-abc"
exp = datetime.now(UTC) + timedelta(hours=1)
with patch("libs.session_revocation_storage.redis_client", mock_redis):
storage.revoke(token_id, exp)
assert mock_redis.setex.called
key, ttl, value = mock_redis.setex.call_args[0]
assert key == f"passport:blacklist:jti:{token_id}"
assert 3500 <= int(ttl) <= 3600
assert value in (b"1", "1")
mock_redis.exists.return_value = True
assert storage.is_revoked(token_id) is True
def test_factory_returns_null_by_default_and_redis_when_configured(self):
import libs.session_revocation_storage as srs
srs._singleton = None
with patch("libs.session_revocation_storage.dify_config") as cfg:
cfg.SESSION_REVOCATION_STORAGE = ""
inst = get_session_revocation_storage()
assert isinstance(inst, NullSessionRevocationStorage)
srs._singleton = None
with patch("libs.session_revocation_storage.dify_config") as cfg:
cfg.SESSION_REVOCATION_STORAGE = "redis"
inst = get_session_revocation_storage()
assert isinstance(inst, RedisSessionRevocationStorage)