mirror of https://github.com/langgenius/dify.git
Merge 7d4425b2dd into 508350ec6a
This commit is contained in:
commit
0418e7c371
|
|
@ -29,6 +29,14 @@ class SecurityConfig(BaseSettings):
|
|||
default="",
|
||||
)
|
||||
|
||||
SESSION_REVOCATION_STORAGE: str = Field(
|
||||
description=(
|
||||
"Session revocation storage backend for JWT jti revocation checks. "
|
||||
"Options: 'null' (default, disabled) or 'redis' (use configured Redis)."
|
||||
),
|
||||
default="null",
|
||||
)
|
||||
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="Duration in minutes for which a password reset token remains valid",
|
||||
default=5,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from libs.token import (
|
|||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_access_token,
|
||||
extract_refresh_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
|
|
@ -157,7 +158,9 @@ class LogoutApi(Resource):
|
|||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
AccountService.logout(account=account)
|
||||
# Extract access_token from request to revoke it
|
||||
access_token = extract_access_token(request)
|
||||
AccountService.logout(account=account, access_token=access_token)
|
||||
flask_login.logout_user()
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,65 @@
|
|||
import time
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import jwt
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from libs.session_revocation_storage import get_session_revocation_storage
|
||||
|
||||
|
||||
def _get_blacklist_key(jti: str) -> str:
|
||||
return f"passport:blacklist:jti:{jti}"
|
||||
|
||||
|
||||
class PassportService:
|
||||
def __init__(self):
|
||||
self.sk = dify_config.SECRET_KEY
|
||||
|
||||
@classmethod
|
||||
def _get_blacklist_key(cls, jti: str) -> str:
|
||||
"""Instance-accessible helper for tests and internal use."""
|
||||
return _get_blacklist_key(jti)
|
||||
|
||||
def issue(self, payload):
|
||||
return jwt.encode(payload, self.sk, algorithm="HS256")
|
||||
# Add jti (JWT ID) if not present for token revocation support
|
||||
payload_to_encode = dict(payload)
|
||||
if "jti" not in payload_to_encode:
|
||||
payload_to_encode["jti"] = str(uuid.uuid4())
|
||||
return jwt.encode(payload_to_encode, self.sk, algorithm="HS256")
|
||||
|
||||
@classmethod
|
||||
def revoke(cls, token: str) -> bool:
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
except jwt.PyJWTError:
|
||||
# Invalid/garbled token: treat as non-revocable
|
||||
return False
|
||||
|
||||
token_id = payload.get("jti") or token
|
||||
exp = payload.get("exp")
|
||||
if not exp:
|
||||
return False
|
||||
|
||||
ttl = int(exp - time.time())
|
||||
if ttl <= 0:
|
||||
return False
|
||||
|
||||
storage = get_session_revocation_storage()
|
||||
storage.revoke(token_id, datetime.fromtimestamp(exp, tz=UTC))
|
||||
return True
|
||||
|
||||
def verify(self, token):
|
||||
"""Verify a JWT and then enforce revocation via SessionRevocationStorage.
|
||||
|
||||
The signature and standard claims are verified first to avoid any processing
|
||||
of untrusted data for invalid tokens. After successful verification we consult
|
||||
the configured revocation storage using the token's `jti` when present.
|
||||
"""
|
||||
# 1) Verify signature/claims first
|
||||
try:
|
||||
return jwt.decode(token, self.sk, algorithms=["HS256"])
|
||||
verified_payload = jwt.decode(token, self.sk, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise Unauthorized("Token has expired.")
|
||||
except jwt.InvalidSignatureError:
|
||||
|
|
@ -22,3 +68,11 @@ class PassportService:
|
|||
raise Unauthorized("Invalid token.")
|
||||
except jwt.PyJWTError: # Catch-all for other JWT errors
|
||||
raise Unauthorized("Invalid token.")
|
||||
|
||||
# 2) Enforce revocation using storage (supports old tokens without jti)
|
||||
storage = get_session_revocation_storage()
|
||||
token_id = verified_payload.get("jti") or token
|
||||
if storage.is_revoked(token_id):
|
||||
raise Unauthorized("Token has been revoked.")
|
||||
|
||||
return verified_payload
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SessionRevocationStorage(Protocol):
|
||||
def revoke(self, token_id: str, expiration_time: datetime) -> None: ...
|
||||
def is_revoked(self, token_id: str) -> bool: ...
|
||||
def expunge(self) -> None: ...
|
||||
|
||||
|
||||
class NullSessionRevocationStorage(SessionRevocationStorage):
|
||||
def revoke(self, token_id: str, expiration_time: datetime) -> None:
|
||||
return None
|
||||
|
||||
def is_revoked(self, token_id: str) -> bool:
|
||||
return False
|
||||
|
||||
def expunge(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class RedisSessionRevocationStorage(SessionRevocationStorage):
|
||||
def __init__(self, key_prefix: str = "passport:blacklist:jti:") -> None:
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
def _key(self, token_id: str) -> str:
|
||||
return f"{self.key_prefix}{token_id}"
|
||||
|
||||
def revoke(self, token_id: str, expiration_time: datetime) -> None:
|
||||
# Compute remaining lifetime in seconds
|
||||
now_ts = datetime.now(UTC).timestamp()
|
||||
ttl = int(expiration_time.timestamp() - now_ts)
|
||||
if ttl <= 0:
|
||||
return None
|
||||
redis_client.setex(self._key(token_id), ttl, b"1")
|
||||
return None
|
||||
|
||||
def is_revoked(self, token_id: str) -> bool:
|
||||
return bool(redis_client.exists(self._key(token_id)))
|
||||
|
||||
def expunge(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
_singleton: SessionRevocationStorage | None = None
|
||||
|
||||
|
||||
def get_session_revocation_storage() -> SessionRevocationStorage:
|
||||
global _singleton
|
||||
if _singleton is not None:
|
||||
return _singleton
|
||||
|
||||
backend = (dify_config.SESSION_REVOCATION_STORAGE or "null").strip().lower()
|
||||
if backend in ("", "null", "disabled", "off"):
|
||||
_singleton = NullSessionRevocationStorage()
|
||||
elif backend == "redis":
|
||||
_singleton = RedisSessionRevocationStorage()
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown SESSION_REVOCATION_STORAGE '%s'; falling back to 'null' (disabled).",
|
||||
backend,
|
||||
)
|
||||
_singleton = NullSessionRevocationStorage()
|
||||
return _singleton
|
||||
|
|
@ -440,7 +440,12 @@ class AccountService:
|
|||
return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token)
|
||||
|
||||
@staticmethod
|
||||
def logout(*, account: Account):
|
||||
def logout(*, account: Account, access_token: str | None = None):
|
||||
# Revoke access_token if provided
|
||||
if access_token:
|
||||
PassportService.revoke(access_token)
|
||||
|
||||
# Delete refresh_token
|
||||
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
||||
if refresh_token:
|
||||
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||
|
|
|
|||
|
|
@ -462,7 +462,7 @@ class TestLogoutApi:
|
|||
response = logout_api.post()
|
||||
|
||||
# Assert
|
||||
mock_service_logout.assert_called_once_with(account=mock_account)
|
||||
mock_service_logout.assert_called_once_with(account=mock_account, access_token=None)
|
||||
mock_logout_user.assert_called_once()
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
|
|
@ -493,3 +493,135 @@ class TestLogoutApi:
|
|||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.AccountService.logout")
|
||||
@patch("controllers.console.auth.login.flask_login.logout_user")
|
||||
@patch("controllers.console.auth.login.extract_access_token")
|
||||
def test_logout_with_access_token_in_cookie(
|
||||
self,
|
||||
mock_extract_access_token,
|
||||
mock_logout_user,
|
||||
mock_service_logout,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test logout with access token in cookie.
|
||||
|
||||
Verifies that:
|
||||
- Access token is extracted from the request
|
||||
- Token is passed to AccountService.logout for revocation
|
||||
- User session is terminated
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
test_access_token = "test.jwt.token"
|
||||
|
||||
# Mock extract_access_token to return a token
|
||||
mock_extract_access_token.return_value = test_access_token
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
|
||||
# Assert
|
||||
# Verify extract_access_token was called with request
|
||||
mock_extract_access_token.assert_called_once()
|
||||
# Verify logout was called with the access token
|
||||
mock_service_logout.assert_called_once_with(account=mock_account, access_token=test_access_token)
|
||||
mock_logout_user.assert_called_once()
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.AccountService.logout")
|
||||
@patch("controllers.console.auth.login.flask_login.logout_user")
|
||||
@patch("controllers.console.auth.login.extract_access_token")
|
||||
def test_logout_with_access_token_in_authorization_header(
|
||||
self,
|
||||
mock_extract_access_token,
|
||||
mock_logout_user,
|
||||
mock_service_logout,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test logout with access token in Authorization header.
|
||||
|
||||
Verifies that:
|
||||
- Access token is extracted from Authorization header
|
||||
- Token is passed to AccountService.logout for revocation
|
||||
- User session is terminated
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
test_access_token = "bearer.token.from.header"
|
||||
|
||||
# Mock extract_access_token to return a token from header
|
||||
mock_extract_access_token.return_value = test_access_token
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/logout", method="POST", headers={"Authorization": f"Bearer {test_access_token}"}
|
||||
):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
|
||||
# Assert
|
||||
# Verify logout was called with the access token from header
|
||||
mock_service_logout.assert_called_once_with(account=mock_account, access_token=test_access_token)
|
||||
mock_logout_user.assert_called_once()
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.AccountService.logout")
|
||||
@patch("controllers.console.auth.login.flask_login.logout_user")
|
||||
@patch("controllers.console.auth.login.extract_access_token")
|
||||
def test_logout_with_no_access_token(
|
||||
self,
|
||||
mock_extract_access_token,
|
||||
mock_logout_user,
|
||||
mock_service_logout,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test logout when no access token is present.
|
||||
|
||||
Verifies that:
|
||||
- Logout proceeds without error when no token is present
|
||||
- AccountService.logout is called with access_token=None
|
||||
- User session is still terminated
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
|
||||
# Mock extract_access_token to return None (no token in request)
|
||||
mock_extract_access_token.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
|
||||
# Assert
|
||||
# Verify logout was called with access_token=None
|
||||
mock_service_logout.assert_called_once_with(account=mock_account, access_token=None)
|
||||
mock_logout_user.assert_called_once()
|
||||
assert response.json["result"] == "success"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
|
@ -35,9 +35,12 @@ class TestPassportService:
|
|||
assert isinstance(token, str)
|
||||
assert len(token.split(".")) == 3 # JWT format: header.payload.signature
|
||||
|
||||
# Verify token content
|
||||
# Verify token content (jti is automatically added)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
assert decoded["user_id"] == payload["user_id"]
|
||||
assert decoded["app_code"] == payload["app_code"]
|
||||
assert "jti" in decoded # jti is automatically added
|
||||
assert isinstance(decoded["jti"], str)
|
||||
|
||||
def test_should_handle_different_payload_types(self, passport_service):
|
||||
"""Test issuing and verifying tokens with different payload types"""
|
||||
|
|
@ -57,7 +60,11 @@ class TestPassportService:
|
|||
for payload in test_cases:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
# Verify all original fields are present
|
||||
for key, value in payload.items():
|
||||
assert decoded[key] == value
|
||||
# Verify jti is added
|
||||
assert "jti" in decoded
|
||||
|
||||
# Security tests
|
||||
def test_should_reject_modified_token(self, passport_service):
|
||||
|
|
@ -153,7 +160,10 @@ class TestPassportService:
|
|||
payload = {"test": "data"}
|
||||
token = service.issue(payload)
|
||||
decoded = service.verify(token)
|
||||
assert decoded == payload
|
||||
# Verify original payload fields are present
|
||||
assert decoded["test"] == payload["test"]
|
||||
# jti is automatically added
|
||||
assert "jti" in decoded
|
||||
|
||||
def test_should_handle_none_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is None"""
|
||||
|
|
@ -192,7 +202,11 @@ class TestPassportService:
|
|||
for payload in special_payloads:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
# Verify all original fields are present
|
||||
for key, value in payload.items():
|
||||
assert decoded[key] == value
|
||||
# jti is automatically added
|
||||
assert "jti" in decoded
|
||||
|
||||
def test_should_catch_generic_pyjwt_errors(self, passport_service):
|
||||
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
|
||||
|
|
@ -203,3 +217,133 @@ class TestPassportService:
|
|||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify("some-token")
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
|
||||
# Token blacklist tests
|
||||
def test_should_revoke_token_successfully(self, passport_service):
|
||||
"""Test that tokens can be revoked and added to blacklist using jti"""
|
||||
payload = {"user_id": "123", "exp": (datetime.now(UTC) + timedelta(hours=1)).timestamp()}
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
# Mock redis_client (via storage layer) and force storage backend to redis
|
||||
mock_redis = MagicMock()
|
||||
import libs.session_revocation_storage as srs
|
||||
|
||||
srs._singleton = None
|
||||
with (
|
||||
patch("libs.session_revocation_storage.dify_config") as s_cfg,
|
||||
patch("libs.session_revocation_storage.redis_client", mock_redis),
|
||||
):
|
||||
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
|
||||
result = passport_service.revoke(token)
|
||||
|
||||
assert result is True
|
||||
mock_redis.setex.assert_called_once()
|
||||
# Verify the key format uses jti
|
||||
call_args = mock_redis.setex.call_args
|
||||
key = call_args[0][0]
|
||||
assert "passport:blacklist:jti:" in key
|
||||
# Verify TTL is approximately 1 hour (3600 seconds)
|
||||
ttl = call_args[0][1]
|
||||
assert 3500 < ttl <= 3600 # Allow some tolerance
|
||||
|
||||
def test_should_not_revoke_already_expired_token(self, passport_service):
|
||||
"""Test that already expired tokens are not added to blacklist"""
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
payload = {"user_id": "123", "exp": past_time.timestamp()}
|
||||
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
mock_redis = MagicMock()
|
||||
import libs.session_revocation_storage as srs
|
||||
|
||||
srs._singleton = None
|
||||
with (
|
||||
patch("libs.session_revocation_storage.dify_config") as s_cfg,
|
||||
patch("libs.session_revocation_storage.redis_client", mock_redis),
|
||||
):
|
||||
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
|
||||
result = passport_service.revoke(token)
|
||||
|
||||
assert result is False
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
def test_should_reject_revoked_token(self, passport_service):
|
||||
"""Test that revoked tokens cannot be verified"""
|
||||
payload = {"user_id": "123"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Get the jti from the token
|
||||
decoded = jwt.decode(token, options={"verify_signature": False})
|
||||
jti = decoded.get("jti")
|
||||
|
||||
# Mock redis to simulate token being revoked (jti in blacklist)
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.exists.return_value = True
|
||||
|
||||
with patch("libs.session_revocation_storage.redis_client", mock_redis):
|
||||
import libs.session_revocation_storage as srs
|
||||
|
||||
srs._singleton = None
|
||||
with patch("libs.session_revocation_storage.dify_config") as s_cfg:
|
||||
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(token)
|
||||
|
||||
assert "revoked" in str(exc_info.value).lower()
|
||||
# Verify that jti was used to check blacklist
|
||||
mock_redis.exists.assert_called_once_with(f"passport:blacklist:jti:{jti}")
|
||||
|
||||
def test_should_verify_non_revoked_token(self, passport_service):
|
||||
"""Test that non-revoked tokens can be verified"""
|
||||
payload = {"user_id": "123"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Get the jti from the token
|
||||
decoded = jwt.decode(token, options={"verify_signature": False})
|
||||
jti = decoded.get("jti")
|
||||
|
||||
# Mock redis to return False (token not in blacklist)
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.exists.return_value = False
|
||||
|
||||
with patch("libs.session_revocation_storage.redis_client", mock_redis):
|
||||
import libs.session_revocation_storage as srs
|
||||
|
||||
srs._singleton = None
|
||||
with patch("libs.session_revocation_storage.dify_config") as s_cfg:
|
||||
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
|
||||
decoded = passport_service.verify(token)
|
||||
|
||||
assert decoded["user_id"] == payload["user_id"]
|
||||
assert "jti" in decoded
|
||||
mock_redis.exists.assert_called_once_with(f"passport:blacklist:jti:{jti}")
|
||||
|
||||
def test_should_handle_revoke_with_invalid_token(self, passport_service):
|
||||
"""Test that revoke handles invalid tokens gracefully"""
|
||||
invalid_token = "invalid.token.here"
|
||||
|
||||
mock_redis = MagicMock()
|
||||
import libs.session_revocation_storage as srs
|
||||
|
||||
srs._singleton = None
|
||||
with (
|
||||
patch("libs.session_revocation_storage.dify_config") as s_cfg,
|
||||
patch("libs.session_revocation_storage.redis_client", mock_redis),
|
||||
):
|
||||
s_cfg.SESSION_REVOCATION_STORAGE = "redis"
|
||||
result = passport_service.revoke(invalid_token)
|
||||
|
||||
assert result is False
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
def test_should_generate_correct_blacklist_key(self, passport_service):
|
||||
"""Test that blacklist key is generated correctly using jti"""
|
||||
jti = "test-jti-uuid"
|
||||
expected_key = f"passport:blacklist:jti:{jti}"
|
||||
|
||||
actual_key = passport_service._get_blacklist_key(jti)
|
||||
assert actual_key == expected_key
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from libs.session_revocation_storage import (
|
||||
NullSessionRevocationStorage,
|
||||
RedisSessionRevocationStorage,
|
||||
get_session_revocation_storage,
|
||||
)
|
||||
|
||||
|
||||
class TestSessionRevocationStorage:
|
||||
def test_null_storage_behaviour(self):
|
||||
s = NullSessionRevocationStorage()
|
||||
s.revoke("any", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert s.is_revoked("any") is False
|
||||
s.expunge()
|
||||
|
||||
def test_redis_storage_revoke_sets_ttl_and_is_revoked_true(self):
|
||||
mock_redis = MagicMock()
|
||||
storage = RedisSessionRevocationStorage()
|
||||
|
||||
token_id = "jti-abc"
|
||||
exp = datetime.now(UTC) + timedelta(hours=1)
|
||||
|
||||
with patch("libs.session_revocation_storage.redis_client", mock_redis):
|
||||
storage.revoke(token_id, exp)
|
||||
assert mock_redis.setex.called
|
||||
key, ttl, value = mock_redis.setex.call_args[0]
|
||||
assert key == f"passport:blacklist:jti:{token_id}"
|
||||
assert 3500 <= int(ttl) <= 3600
|
||||
assert value in (b"1", "1")
|
||||
|
||||
mock_redis.exists.return_value = True
|
||||
assert storage.is_revoked(token_id) is True
|
||||
|
||||
def test_factory_returns_null_by_default_and_redis_when_configured(self):
|
||||
import libs.session_revocation_storage as srs
|
||||
|
||||
srs._singleton = None
|
||||
with patch("libs.session_revocation_storage.dify_config") as cfg:
|
||||
cfg.SESSION_REVOCATION_STORAGE = ""
|
||||
inst = get_session_revocation_storage()
|
||||
assert isinstance(inst, NullSessionRevocationStorage)
|
||||
|
||||
srs._singleton = None
|
||||
with patch("libs.session_revocation_storage.dify_config") as cfg:
|
||||
cfg.SESSION_REVOCATION_STORAGE = "redis"
|
||||
inst = get_session_revocation_storage()
|
||||
assert isinstance(inst, RedisSessionRevocationStorage)
|
||||
Loading…
Reference in New Issue