diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b5..2e1b723e82 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b0768..6e183b70e3 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b002706931..c9e5610aea 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d1..cbdc908690 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py index b5d91ef3fb..388504c07f 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py @@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase): class TestApiKeyAuthBase: def test_should_store_credentials_on_init(self): """Test that credentials are properly stored during initialization""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) assert auth.credentials == credentials def test_should_not_instantiate_abstract_class(self): """Test that ApiKeyAuthBase cannot be instantiated directly""" - credentials = {"api_key": "test_key"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} with pytest.raises(TypeError) as exc_info: ApiKeyAuthBase(credentials) @@ -29,7 +29,7 @@ class TestApiKeyAuthBase: def test_should_allow_subclass_implementation(self): """Test that subclasses can properly implement the abstract method""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) # Should not raise any exception diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py index 60af6e20c2..b1f7cf24f3 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py @@ -58,7 +58,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) result = factory.validate_credentials() # Assert @@ -75,7 +75,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act & Assert - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) with pytest.raises(Exception) as exc_info: factory.validate_credentials() assert str(exc_info.value) == "Authentication error"