refactor(api): type auth service credentials with TypedDict (#33867)

This commit is contained in:
BitToby 2026-03-24 06:22:17 +02:00 committed by GitHub
parent 0589fa423b
commit ecd3a964c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 24 additions and 16 deletions

View File

@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
class AuthCredentials(TypedDict):
auth_type: str
config: dict[str, Any]
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
self.credentials = credentials
@abstractmethod

View File

@ -1,9 +1,9 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase):
class TestApiKeyAuthBase:
def test_should_store_credentials_on_init(self):
"""Test that credentials are properly stored during initialization"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials == credentials
def test_should_not_instantiate_abstract_class(self):
"""Test that ApiKeyAuthBase cannot be instantiated directly"""
credentials = {"api_key": "test_key"}
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
with pytest.raises(TypeError) as exc_info:
ApiKeyAuthBase(credentials)
@ -29,7 +29,7 @@ class TestApiKeyAuthBase:
def test_should_allow_subclass_implementation(self):
"""Test that subclasses can properly implement the abstract method"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
auth = ConcreteApiKeyAuth(credentials)
# Should not raise any exception

View File

@ -58,7 +58,7 @@ class TestApiKeyAuthFactory:
mock_get_factory.return_value = mock_auth_class
# Act
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}})
result = factory.validate_credentials()
# Assert
@ -75,7 +75,7 @@ class TestApiKeyAuthFactory:
mock_get_factory.return_value = mock_auth_class
# Act & Assert
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}})
with pytest.raises(Exception) as exc_info:
factory.validate_credentials()
assert str(exc_info.value) == "Authentication error"