From 18a52caef06eedca6a05a7fb9122c0af47d54c21 Mon Sep 17 00:00:00 2001 From: fatelei Date: Sat, 31 Jan 2026 21:47:47 +0800 Subject: [PATCH 1/3] fix: fix access token is still valid after logout --- api/controllers/console/auth/login.py | 5 +- api/libs/passport.py | 70 ++++++++- api/services/account_service.py | 7 +- .../console/auth/test_login_logout.py | 134 +++++++++++++++++- api/tests/unit_tests/libs/test_passport.py | 127 ++++++++++++++++- 5 files changed, 332 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 400df138b8..35e6b8a93f 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -37,6 +37,7 @@ from libs.token import ( clear_access_token_from_cookie, clear_csrf_token_from_cookie, clear_refresh_token_from_cookie, + extract_access_token, extract_refresh_token, set_access_token_to_cookie, set_csrf_token_to_cookie, @@ -157,7 +158,9 @@ class LogoutApi(Resource): if isinstance(account, flask_login.AnonymousUserMixin): response = make_response({"result": "success"}) else: - AccountService.logout(account=account) + # Extract access_token from request to revoke it + access_token = extract_access_token(request) + AccountService.logout(account=account, access_token=access_token) flask_login.logout_user() response = make_response({"result": "success"}) diff --git a/api/libs/passport.py b/api/libs/passport.py index 22dd20b73b..7c494e67d4 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -1,19 +1,73 @@ +import time +import uuid + import jwt from werkzeug.exceptions import Unauthorized from configs import dify_config +from extensions.ext_redis import redis_client + + +def _get_blacklist_key(jti: str) -> str: + """Generate Redis key for token blacklist using JWT ID.""" + return f"passport:blacklist:jti:{jti}" class PassportService: def __init__(self): self.sk = dify_config.SECRET_KEY + @classmethod + def _get_blacklist_key(cls, jti: str) -> str: + """Instance-accessible helper for tests and internal use.""" + return _get_blacklist_key(jti) + def issue(self, payload): - return jwt.encode(payload, self.sk, algorithm="HS256") + # Add jti (JWT ID) if not present for token revocation support + payload_to_encode = dict(payload) + if "jti" not in payload_to_encode: + payload_to_encode["jti"] = str(uuid.uuid4()) + return jwt.encode(payload_to_encode, self.sk, algorithm="HS256") + + @classmethod + def revoke(cls, token: str) -> bool: + """Add token to blacklist until its expiration using JWT ID (jti). + + Returns False if the token is invalid, missing exp/jti, or already expired. + """ + try: + payload = jwt.decode(token, options={"verify_signature": False}) + except jwt.PyJWTError: + # Invalid/garbled token: treat as non-revocable + return False + + jti = payload.get("jti") + if not jti: + # Fallback for tokens without jti (old format) + # Use the full token as key for backward compatibility + jti = token + + exp = payload.get("exp") + if not exp: + return False + + ttl = int(exp - time.time()) + if ttl <= 0: + return False + + redis_client.setex(cls._get_blacklist_key(jti), ttl, "1") + return True def verify(self, token): + """Verify a JWT and then enforce revocation via Redis blacklist. + + The signature and standard claims are verified first to avoid any processing + of untrusted data (including Redis lookups) for invalid tokens. Only after a + successful verification do we consult the blacklist using the token's `jti`. + """ + # 1) Verify signature/claims first try: - return jwt.decode(token, self.sk, algorithms=["HS256"]) + verified_payload = jwt.decode(token, self.sk, algorithms=["HS256"]) except jwt.ExpiredSignatureError: raise Unauthorized("Token has expired.") except jwt.InvalidSignatureError: @@ -22,3 +76,15 @@ class PassportService: raise Unauthorized("Invalid token.") except jwt.PyJWTError: # Catch-all for other JWT errors raise Unauthorized("Invalid token.") + + # 2) Enforce revocation via blacklist using jti (if present) + jti = verified_payload.get("jti") + if jti: + if redis_client.exists(self._get_blacklist_key(jti)): + raise Unauthorized("Token has been revoked.") + else: + # Fallback for old tokens without jti + if redis_client.exists(self._get_blacklist_key(token)): + raise Unauthorized("Token has been revoked.") + + return verified_payload diff --git a/api/services/account_service.py b/api/services/account_service.py index bd520f54cf..dbe8c3af00 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -440,7 +440,12 @@ class AccountService: return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token) @staticmethod - def logout(*, account: Account): + def logout(*, account: Account, access_token: str | None = None): + # Revoke access_token if provided + if access_token: + PassportService.revoke(access_token) + + # Delete refresh_token refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) if refresh_token: AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 560971206f..eed7811a49 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -462,7 +462,7 @@ class TestLogoutApi: response = logout_api.post() # Assert - mock_service_logout.assert_called_once_with(account=mock_account) + mock_service_logout.assert_called_once_with(account=mock_account, access_token=None) mock_logout_user.assert_called_once() assert response.json["result"] == "success" @@ -493,3 +493,135 @@ class TestLogoutApi: # Assert assert response.json["result"] == "success" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.current_account_with_tenant") + @patch("controllers.console.auth.login.AccountService.logout") + @patch("controllers.console.auth.login.flask_login.logout_user") + @patch("controllers.console.auth.login.extract_access_token") + def test_logout_with_access_token_in_cookie( + self, + mock_extract_access_token, + mock_logout_user, + mock_service_logout, + mock_current_account, + mock_db, + app, + mock_account, + ): + """ + Test logout with access token in cookie. + + Verifies that: + - Access token is extracted from the request + - Token is passed to AccountService.logout for revocation + - User session is terminated + - Success response is returned + """ + # Arrange + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_current_account.return_value = (mock_account, MagicMock()) + test_access_token = "test.jwt.token" + + # Mock extract_access_token to return a token + mock_extract_access_token.return_value = test_access_token + + # Act + with app.test_request_context("/logout", method="POST"): + logout_api = LogoutApi() + response = logout_api.post() + + # Assert + # Verify extract_access_token was called with request + mock_extract_access_token.assert_called_once() + # Verify logout was called with the access token + mock_service_logout.assert_called_once_with(account=mock_account, access_token=test_access_token) + mock_logout_user.assert_called_once() + assert response.json["result"] == "success" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.current_account_with_tenant") + @patch("controllers.console.auth.login.AccountService.logout") + @patch("controllers.console.auth.login.flask_login.logout_user") + @patch("controllers.console.auth.login.extract_access_token") + def test_logout_with_access_token_in_authorization_header( + self, + mock_extract_access_token, + mock_logout_user, + mock_service_logout, + mock_current_account, + mock_db, + app, + mock_account, + ): + """ + Test logout with access token in Authorization header. + + Verifies that: + - Access token is extracted from Authorization header + - Token is passed to AccountService.logout for revocation + - User session is terminated + - Success response is returned + """ + # Arrange + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_current_account.return_value = (mock_account, MagicMock()) + test_access_token = "bearer.token.from.header" + + # Mock extract_access_token to return a token from header + mock_extract_access_token.return_value = test_access_token + + # Act + with app.test_request_context( + "/logout", method="POST", headers={"Authorization": f"Bearer {test_access_token}"} + ): + logout_api = LogoutApi() + response = logout_api.post() + + # Assert + # Verify logout was called with the access token from header + mock_service_logout.assert_called_once_with(account=mock_account, access_token=test_access_token) + mock_logout_user.assert_called_once() + assert response.json["result"] == "success" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.current_account_with_tenant") + @patch("controllers.console.auth.login.AccountService.logout") + @patch("controllers.console.auth.login.flask_login.logout_user") + @patch("controllers.console.auth.login.extract_access_token") + def test_logout_with_no_access_token( + self, + mock_extract_access_token, + mock_logout_user, + mock_service_logout, + mock_current_account, + mock_db, + app, + mock_account, + ): + """ + Test logout when no access token is present. + + Verifies that: + - Logout proceeds without error when no token is present + - AccountService.logout is called with access_token=None + - User session is still terminated + - Success response is returned + """ + # Arrange + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_current_account.return_value = (mock_account, MagicMock()) + + # Mock extract_access_token to return None (no token in request) + mock_extract_access_token.return_value = None + + # Act + with app.test_request_context("/logout", method="POST"): + logout_api = LogoutApi() + response = logout_api.post() + + # Assert + # Verify logout was called with access_token=None + mock_service_logout.assert_called_once_with(account=mock_account, access_token=None) + mock_logout_user.assert_called_once() + assert response.json["result"] == "success" diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py index f33484c18d..20ccb38b2b 100644 --- a/api/tests/unit_tests/libs/test_passport.py +++ b/api/tests/unit_tests/libs/test_passport.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime, timedelta -from unittest.mock import patch +from unittest.mock import MagicMock, patch import jwt import pytest @@ -35,9 +35,12 @@ class TestPassportService: assert isinstance(token, str) assert len(token.split(".")) == 3 # JWT format: header.payload.signature - # Verify token content + # Verify token content (jti is automatically added) decoded = passport_service.verify(token) - assert decoded == payload + assert decoded["user_id"] == payload["user_id"] + assert decoded["app_code"] == payload["app_code"] + assert "jti" in decoded # jti is automatically added + assert isinstance(decoded["jti"], str) def test_should_handle_different_payload_types(self, passport_service): """Test issuing and verifying tokens with different payload types""" @@ -57,7 +60,11 @@ class TestPassportService: for payload in test_cases: token = passport_service.issue(payload) decoded = passport_service.verify(token) - assert decoded == payload + # Verify all original fields are present + for key, value in payload.items(): + assert decoded[key] == value + # Verify jti is added + assert "jti" in decoded # Security tests def test_should_reject_modified_token(self, passport_service): @@ -153,7 +160,10 @@ class TestPassportService: payload = {"test": "data"} token = service.issue(payload) decoded = service.verify(token) - assert decoded == payload + # Verify original payload fields are present + assert decoded["test"] == payload["test"] + # jti is automatically added + assert "jti" in decoded def test_should_handle_none_secret_key(self): """Test behavior when SECRET_KEY is None""" @@ -192,7 +202,11 @@ class TestPassportService: for payload in special_payloads: token = passport_service.issue(payload) decoded = passport_service.verify(token) - assert decoded == payload + # Verify all original fields are present + for key, value in payload.items(): + assert decoded[key] == value + # jti is automatically added + assert "jti" in decoded def test_should_catch_generic_pyjwt_errors(self, passport_service): """Test that generic PyJWTError exceptions are caught and converted to Unauthorized""" @@ -203,3 +217,104 @@ class TestPassportService: with pytest.raises(Unauthorized) as exc_info: passport_service.verify("some-token") assert str(exc_info.value) == "401 Unauthorized: Invalid token." + + # Token blacklist tests + def test_should_revoke_token_successfully(self, passport_service): + """Test that tokens can be revoked and added to blacklist using jti""" + payload = {"user_id": "123", "exp": (datetime.now(UTC) + timedelta(hours=1)).timestamp()} + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "test-secret-key-for-testing" + token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256") + + # Mock redis_client + mock_redis = MagicMock() + with patch("libs.passport.redis_client", mock_redis): + result = passport_service.revoke(token) + + assert result is True + mock_redis.setex.assert_called_once() + # Verify the key format uses jti + call_args = mock_redis.setex.call_args + key = call_args[0][0] + assert "passport:blacklist:jti:" in key + # Verify TTL is approximately 1 hour (3600 seconds) + ttl = call_args[0][1] + assert 3500 < ttl <= 3600 # Allow some tolerance + + def test_should_not_revoke_already_expired_token(self, passport_service): + """Test that already expired tokens are not added to blacklist""" + past_time = datetime.now(UTC) - timedelta(hours=1) + payload = {"user_id": "123", "exp": past_time.timestamp()} + + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "test-secret-key-for-testing" + token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256") + + # Mock redis_client + mock_redis = MagicMock() + with patch("libs.passport.redis_client", mock_redis): + result = passport_service.revoke(token) + + assert result is False + mock_redis.setex.assert_not_called() + + def test_should_reject_revoked_token(self, passport_service): + """Test that revoked tokens cannot be verified""" + payload = {"user_id": "123"} + token = passport_service.issue(payload) + + # Get the jti from the token + decoded = jwt.decode(token, options={"verify_signature": False}) + jti = decoded.get("jti") + + # Mock redis to simulate token being revoked (jti in blacklist) + mock_redis = MagicMock() + mock_redis.exists.return_value = True + + with patch("libs.passport.redis_client", mock_redis): + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify(token) + + assert "revoked" in str(exc_info.value).lower() + # Verify that jti was used to check blacklist + mock_redis.exists.assert_called_once_with(f"passport:blacklist:jti:{jti}") + + def test_should_verify_non_revoked_token(self, passport_service): + """Test that non-revoked tokens can be verified""" + payload = {"user_id": "123"} + token = passport_service.issue(payload) + + # Get the jti from the token + decoded = jwt.decode(token, options={"verify_signature": False}) + jti = decoded.get("jti") + + # Mock redis to return False (token not in blacklist) + mock_redis = MagicMock() + mock_redis.exists.return_value = False + + with patch("libs.passport.redis_client", mock_redis): + decoded = passport_service.verify(token) + + assert decoded["user_id"] == payload["user_id"] + assert "jti" in decoded + mock_redis.exists.assert_called_once_with(f"passport:blacklist:jti:{jti}") + + def test_should_handle_revoke_with_invalid_token(self, passport_service): + """Test that revoke handles invalid tokens gracefully""" + invalid_token = "invalid.token.here" + + # Mock redis_client + mock_redis = MagicMock() + with patch("libs.passport.redis_client", mock_redis): + result = passport_service.revoke(invalid_token) + + assert result is False + mock_redis.setex.assert_not_called() + + def test_should_generate_correct_blacklist_key(self, passport_service): + """Test that blacklist key is generated correctly using jti""" + jti = "test-jti-uuid" + expected_key = f"passport:blacklist:jti:{jti}" + + actual_key = passport_service._get_blacklist_key(jti) + assert actual_key == expected_key From b0977375603186f93c9dd3325110c3705c312b3f Mon Sep 17 00:00:00 2001 From: fatelei Date: Tue, 17 Mar 2026 14:13:00 +0800 Subject: [PATCH 2/3] feat: add session revoke storage --- api/configs/feature/__init__.py | 8 ++ api/libs/passport.py | 38 ++++------ api/libs/session_revocation_storage.py | 74 +++++++++++++++++++ api/tests/unit_tests/libs/test_passport.py | 40 +++++++--- .../libs/test_session_revocation_storage.py | 50 +++++++++++++ 5 files changed, 174 insertions(+), 36 deletions(-) create mode 100644 api/libs/session_revocation_storage.py create mode 100644 api/tests/unit_tests/libs/test_session_revocation_storage.py diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d37cff63e9..e607c463bc 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -29,6 +29,14 @@ class SecurityConfig(BaseSettings): default="", ) + SESSION_REVOCATION_STORAGE: str = Field( + description=( + "Session revocation storage backend for JWT jti revocation checks. " + "Options: 'null' (default, disabled) or 'redis' (use configured Redis)." + ), + default="null", + ) + RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( description="Duration in minutes for which a password reset token remains valid", default=5, diff --git a/api/libs/passport.py b/api/libs/passport.py index 7c494e67d4..7b23a32ee2 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -1,15 +1,15 @@ import time import uuid +from datetime import UTC, datetime import jwt from werkzeug.exceptions import Unauthorized from configs import dify_config -from extensions.ext_redis import redis_client +from libs.session_revocation_storage import get_session_revocation_storage def _get_blacklist_key(jti: str) -> str: - """Generate Redis key for token blacklist using JWT ID.""" return f"passport:blacklist:jti:{jti}" @@ -31,22 +31,13 @@ class PassportService: @classmethod def revoke(cls, token: str) -> bool: - """Add token to blacklist until its expiration using JWT ID (jti). - - Returns False if the token is invalid, missing exp/jti, or already expired. - """ try: payload = jwt.decode(token, options={"verify_signature": False}) except jwt.PyJWTError: # Invalid/garbled token: treat as non-revocable return False - jti = payload.get("jti") - if not jti: - # Fallback for tokens without jti (old format) - # Use the full token as key for backward compatibility - jti = token - + token_id = payload.get("jti") or token exp = payload.get("exp") if not exp: return False @@ -55,15 +46,16 @@ class PassportService: if ttl <= 0: return False - redis_client.setex(cls._get_blacklist_key(jti), ttl, "1") + storage = get_session_revocation_storage() + storage.revoke(token_id, datetime.fromtimestamp(exp, tz=UTC)) return True def verify(self, token): - """Verify a JWT and then enforce revocation via Redis blacklist. + """Verify a JWT and then enforce revocation via SessionRevocationStorage. The signature and standard claims are verified first to avoid any processing - of untrusted data (including Redis lookups) for invalid tokens. Only after a - successful verification do we consult the blacklist using the token's `jti`. + of untrusted data for invalid tokens. After successful verification we consult + the configured revocation storage using the token's `jti` when present. """ # 1) Verify signature/claims first try: @@ -77,14 +69,10 @@ class PassportService: except jwt.PyJWTError: # Catch-all for other JWT errors raise Unauthorized("Invalid token.") - # 2) Enforce revocation via blacklist using jti (if present) - jti = verified_payload.get("jti") - if jti: - if redis_client.exists(self._get_blacklist_key(jti)): - raise Unauthorized("Token has been revoked.") - else: - # Fallback for old tokens without jti - if redis_client.exists(self._get_blacklist_key(token)): - raise Unauthorized("Token has been revoked.") + # 2) Enforce revocation using storage (supports old tokens without jti) + storage = get_session_revocation_storage() + token_id = verified_payload.get("jti") or token + if storage.is_revoked(token_id): + raise Unauthorized("Token has been revoked.") return verified_payload diff --git a/api/libs/session_revocation_storage.py b/api/libs/session_revocation_storage.py new file mode 100644 index 0000000000..7bef07b55d --- /dev/null +++ b/api/libs/session_revocation_storage.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import logging +from datetime import UTC, datetime +from typing import Protocol, runtime_checkable + +from configs import dify_config +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class SessionRevocationStorage(Protocol): + def revoke(self, token_id: str, expiration_time: datetime) -> None: ... + def is_revoked(self, token_id: str) -> bool: ... + def expunge(self) -> None: ... + + +class NullSessionRevocationStorage(SessionRevocationStorage): + def revoke(self, token_id: str, expiration_time: datetime) -> None: + return None + + def is_revoked(self, token_id: str) -> bool: + return False + + def expunge(self) -> None: + return None + + +class RedisSessionRevocationStorage(SessionRevocationStorage): + + def __init__(self, key_prefix: str = "passport:blacklist:jti:") -> None: + self.key_prefix = key_prefix + + def _key(self, token_id: str) -> str: + return f"{self.key_prefix}{token_id}" + + def revoke(self, token_id: str, expiration_time: datetime) -> None: + # Compute remaining lifetime in seconds + now_ts = datetime.now(UTC).timestamp() + ttl = int(expiration_time.timestamp() - now_ts) + if ttl <= 0: + return None + redis_client.setex(self._key(token_id), ttl, b"1") + return None + + def is_revoked(self, token_id: str) -> bool: + return bool(redis_client.exists(self._key(token_id))) + + def expunge(self) -> None: + return None + + +_singleton: SessionRevocationStorage | None = None + + +def get_session_revocation_storage() -> SessionRevocationStorage: + global _singleton + if _singleton is not None: + return _singleton + + backend = (dify_config.SESSION_REVOCATION_STORAGE or "null").strip().lower() + if backend in ("", "null", "disabled", "off"): + _singleton = NullSessionRevocationStorage() + elif backend == "redis": + _singleton = RedisSessionRevocationStorage() + else: + logger.warning( + "Unknown SESSION_REVOCATION_STORAGE '%s'; falling back to 'null' (disabled).", + backend, + ) + _singleton = NullSessionRevocationStorage() + return _singleton diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py index 20ccb38b2b..9e76e0c0fb 100644 --- a/api/tests/unit_tests/libs/test_passport.py +++ b/api/tests/unit_tests/libs/test_passport.py @@ -226,9 +226,13 @@ class TestPassportService: mock_config.SECRET_KEY = "test-secret-key-for-testing" token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256") - # Mock redis_client + # Mock redis_client (via storage layer) and force storage backend to redis mock_redis = MagicMock() - with patch("libs.passport.redis_client", mock_redis): + import libs.session_revocation_storage as srs + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as s_cfg, \ + patch("libs.session_revocation_storage.redis_client", mock_redis): + s_cfg.SESSION_REVOCATION_STORAGE = "redis" result = passport_service.revoke(token) assert result is True @@ -250,9 +254,12 @@ class TestPassportService: mock_config.SECRET_KEY = "test-secret-key-for-testing" token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256") - # Mock redis_client mock_redis = MagicMock() - with patch("libs.passport.redis_client", mock_redis): + import libs.session_revocation_storage as srs + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as s_cfg, \ + patch("libs.session_revocation_storage.redis_client", mock_redis): + s_cfg.SESSION_REVOCATION_STORAGE = "redis" result = passport_service.revoke(token) assert result is False @@ -271,9 +278,13 @@ class TestPassportService: mock_redis = MagicMock() mock_redis.exists.return_value = True - with patch("libs.passport.redis_client", mock_redis): - with pytest.raises(Unauthorized) as exc_info: - passport_service.verify(token) + with patch("libs.session_revocation_storage.redis_client", mock_redis): + import libs.session_revocation_storage as srs + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as s_cfg: + s_cfg.SESSION_REVOCATION_STORAGE = "redis" + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify(token) assert "revoked" in str(exc_info.value).lower() # Verify that jti was used to check blacklist @@ -292,8 +303,12 @@ class TestPassportService: mock_redis = MagicMock() mock_redis.exists.return_value = False - with patch("libs.passport.redis_client", mock_redis): - decoded = passport_service.verify(token) + with patch("libs.session_revocation_storage.redis_client", mock_redis): + import libs.session_revocation_storage as srs + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as s_cfg: + s_cfg.SESSION_REVOCATION_STORAGE = "redis" + decoded = passport_service.verify(token) assert decoded["user_id"] == payload["user_id"] assert "jti" in decoded @@ -303,9 +318,12 @@ class TestPassportService: """Test that revoke handles invalid tokens gracefully""" invalid_token = "invalid.token.here" - # Mock redis_client mock_redis = MagicMock() - with patch("libs.passport.redis_client", mock_redis): + import libs.session_revocation_storage as srs + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as s_cfg, \ + patch("libs.session_revocation_storage.redis_client", mock_redis): + s_cfg.SESSION_REVOCATION_STORAGE = "redis" result = passport_service.revoke(invalid_token) assert result is False diff --git a/api/tests/unit_tests/libs/test_session_revocation_storage.py b/api/tests/unit_tests/libs/test_session_revocation_storage.py new file mode 100644 index 0000000000..f70be44972 --- /dev/null +++ b/api/tests/unit_tests/libs/test_session_revocation_storage.py @@ -0,0 +1,50 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch + +from libs.session_revocation_storage import ( + NullSessionRevocationStorage, + RedisSessionRevocationStorage, + get_session_revocation_storage, +) + + +class TestSessionRevocationStorage: + def test_null_storage_behaviour(self): + s = NullSessionRevocationStorage() + s.revoke("any", datetime.now(UTC) + timedelta(hours=1)) + assert s.is_revoked("any") is False + s.expunge() + + def test_redis_storage_revoke_sets_ttl_and_is_revoked_true(self): + mock_redis = MagicMock() + storage = RedisSessionRevocationStorage() + + token_id = "jti-abc" + exp = datetime.now(UTC) + timedelta(hours=1) + + with patch("libs.session_revocation_storage.redis_client", mock_redis): + storage.revoke(token_id, exp) + assert mock_redis.setex.called + key, ttl, value = mock_redis.setex.call_args[0] + assert key == f"passport:blacklist:jti:{token_id}" + assert 3500 <= int(ttl) <= 3600 + assert value in (b"1", "1") + + mock_redis.exists.return_value = True + assert storage.is_revoked(token_id) is True + + def test_factory_returns_null_by_default_and_redis_when_configured(self): + import libs.session_revocation_storage as srs + + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as cfg: + cfg.SESSION_REVOCATION_STORAGE = "" + inst = get_session_revocation_storage() + assert isinstance(inst, NullSessionRevocationStorage) + + srs._singleton = None + with patch("libs.session_revocation_storage.dify_config") as cfg: + cfg.SESSION_REVOCATION_STORAGE = "redis" + inst = get_session_revocation_storage() + assert isinstance(inst, RedisSessionRevocationStorage) + From 7d4425b2ddf43140c60df8e0386cd7329240d189 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 07:08:23 +0000 Subject: [PATCH 3/3] [autofix.ci] apply automated fixes --- api/libs/session_revocation_storage.py | 1 - api/tests/unit_tests/libs/test_passport.py | 23 ++++++++++++++----- .../libs/test_session_revocation_storage.py | 1 - 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/api/libs/session_revocation_storage.py b/api/libs/session_revocation_storage.py index 7bef07b55d..294ddd9c5e 100644 --- a/api/libs/session_revocation_storage.py +++ b/api/libs/session_revocation_storage.py @@ -29,7 +29,6 @@ class NullSessionRevocationStorage(SessionRevocationStorage): class RedisSessionRevocationStorage(SessionRevocationStorage): - def __init__(self, key_prefix: str = "passport:blacklist:jti:") -> None: self.key_prefix = key_prefix diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py index 9e76e0c0fb..bc00bbaac1 100644 --- a/api/tests/unit_tests/libs/test_passport.py +++ b/api/tests/unit_tests/libs/test_passport.py @@ -229,9 +229,12 @@ class TestPassportService: # Mock redis_client (via storage layer) and force storage backend to redis mock_redis = MagicMock() import libs.session_revocation_storage as srs + srs._singleton = None - with patch("libs.session_revocation_storage.dify_config") as s_cfg, \ - patch("libs.session_revocation_storage.redis_client", mock_redis): + with ( + patch("libs.session_revocation_storage.dify_config") as s_cfg, + patch("libs.session_revocation_storage.redis_client", mock_redis), + ): s_cfg.SESSION_REVOCATION_STORAGE = "redis" result = passport_service.revoke(token) @@ -256,9 +259,12 @@ class TestPassportService: mock_redis = MagicMock() import libs.session_revocation_storage as srs + srs._singleton = None - with patch("libs.session_revocation_storage.dify_config") as s_cfg, \ - patch("libs.session_revocation_storage.redis_client", mock_redis): + with ( + patch("libs.session_revocation_storage.dify_config") as s_cfg, + patch("libs.session_revocation_storage.redis_client", mock_redis), + ): s_cfg.SESSION_REVOCATION_STORAGE = "redis" result = passport_service.revoke(token) @@ -280,6 +286,7 @@ class TestPassportService: with patch("libs.session_revocation_storage.redis_client", mock_redis): import libs.session_revocation_storage as srs + srs._singleton = None with patch("libs.session_revocation_storage.dify_config") as s_cfg: s_cfg.SESSION_REVOCATION_STORAGE = "redis" @@ -305,6 +312,7 @@ class TestPassportService: with patch("libs.session_revocation_storage.redis_client", mock_redis): import libs.session_revocation_storage as srs + srs._singleton = None with patch("libs.session_revocation_storage.dify_config") as s_cfg: s_cfg.SESSION_REVOCATION_STORAGE = "redis" @@ -320,9 +328,12 @@ class TestPassportService: mock_redis = MagicMock() import libs.session_revocation_storage as srs + srs._singleton = None - with patch("libs.session_revocation_storage.dify_config") as s_cfg, \ - patch("libs.session_revocation_storage.redis_client", mock_redis): + with ( + patch("libs.session_revocation_storage.dify_config") as s_cfg, + patch("libs.session_revocation_storage.redis_client", mock_redis), + ): s_cfg.SESSION_REVOCATION_STORAGE = "redis" result = passport_service.revoke(invalid_token) diff --git a/api/tests/unit_tests/libs/test_session_revocation_storage.py b/api/tests/unit_tests/libs/test_session_revocation_storage.py index f70be44972..2927229273 100644 --- a/api/tests/unit_tests/libs/test_session_revocation_storage.py +++ b/api/tests/unit_tests/libs/test_session_revocation_storage.py @@ -47,4 +47,3 @@ class TestSessionRevocationStorage: cfg.SESSION_REVOCATION_STORAGE = "redis" inst = get_session_revocation_storage() assert isinstance(inst, RedisSessionRevocationStorage) -