mirror of https://github.com/langgenius/dify.git
refactor(api): type Firecrawl API responses with TypedDict (#33691)
This commit is contained in:
parent
146f8fac45
commit
b2a388b7bf
|
|
@ -1,12 +1,38 @@
|
|||
import json
|
||||
import time
|
||||
from typing import Any, cast
|
||||
from typing import Any, NotRequired, cast
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class FirecrawlDocumentData(TypedDict):
|
||||
title: str | None
|
||||
description: str | None
|
||||
source_url: str | None
|
||||
markdown: str | None
|
||||
|
||||
|
||||
class CrawlStatusResponse(TypedDict):
|
||||
status: str
|
||||
total: int | None
|
||||
current: int | None
|
||||
data: list[FirecrawlDocumentData]
|
||||
|
||||
|
||||
class MapResponse(TypedDict):
|
||||
success: bool
|
||||
links: list[str]
|
||||
|
||||
|
||||
class SearchResponse(TypedDict):
|
||||
success: bool
|
||||
data: list[dict[str, Any]]
|
||||
warning: NotRequired[str]
|
||||
|
||||
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
|
|
@ -14,7 +40,7 @@ class FirecrawlApp:
|
|||
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict[str, Any]:
|
||||
def scrape_url(self, url, params=None) -> FirecrawlDocumentData:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
|
||||
headers = self._prepare_headers()
|
||||
json_data = {
|
||||
|
|
@ -32,9 +58,7 @@ class FirecrawlApp:
|
|||
return self._extract_common_fields(data)
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "scrape URL")
|
||||
return {} # Avoid additional exception after handling error
|
||||
else:
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
|
||||
|
|
@ -51,7 +75,7 @@ class FirecrawlApp:
|
|||
self._handle_error(response, "start crawl job")
|
||||
return "" # unreachable
|
||||
|
||||
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
def map(self, url: str, params: dict[str, Any] | None = None) -> MapResponse:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
|
||||
headers = self._prepare_headers()
|
||||
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
|
||||
|
|
@ -60,14 +84,12 @@ class FirecrawlApp:
|
|||
json_data.update(params)
|
||||
response = self._post_request(self._build_url("v2/map"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
return cast(dict[str, Any], response.json())
|
||||
return cast(MapResponse, response.json())
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "start map job")
|
||||
return {}
|
||||
else:
|
||||
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
|
||||
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||
def check_crawl_status(self, job_id) -> CrawlStatusResponse:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
|
||||
if response.status_code == 200:
|
||||
|
|
@ -77,7 +99,7 @@ class FirecrawlApp:
|
|||
if total == 0:
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list = []
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = self._extract_common_fields(item)
|
||||
|
|
@ -95,13 +117,15 @@ class FirecrawlApp:
|
|||
return self._format_crawl_status_response(
|
||||
crawl_status_response.get("status"), crawl_status_response, []
|
||||
)
|
||||
else:
|
||||
self._handle_error(response, "check crawl status")
|
||||
return {} # unreachable
|
||||
self._handle_error(response, "check crawl status")
|
||||
raise RuntimeError("unreachable: _handle_error always raises")
|
||||
|
||||
def _format_crawl_status_response(
|
||||
self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
self,
|
||||
status: str,
|
||||
crawl_status_response: dict[str, Any],
|
||||
url_data_list: list[FirecrawlDocumentData],
|
||||
) -> CrawlStatusResponse:
|
||||
return {
|
||||
"status": status,
|
||||
"total": crawl_status_response.get("total"),
|
||||
|
|
@ -109,7 +133,7 @@ class FirecrawlApp:
|
|||
"data": url_data_list,
|
||||
}
|
||||
|
||||
def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
|
||||
def _extract_common_fields(self, item: dict[str, Any]) -> FirecrawlDocumentData:
|
||||
return {
|
||||
"title": item.get("metadata", {}).get("title"),
|
||||
"description": item.get("metadata", {}).get("description"),
|
||||
|
|
@ -117,7 +141,7 @@ class FirecrawlApp:
|
|||
"markdown": item.get("markdown"),
|
||||
}
|
||||
|
||||
def _prepare_headers(self) -> dict[str, Any]:
|
||||
def _prepare_headers(self) -> dict[str, str]:
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _build_url(self, path: str) -> str:
|
||||
|
|
@ -150,10 +174,10 @@ class FirecrawlApp:
|
|||
error_message = response.text or "Unknown error occurred"
|
||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
|
||||
|
||||
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
def search(self, query: str, params: dict[str, Any] | None = None) -> SearchResponse:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
|
||||
headers = self._prepare_headers()
|
||||
json_data = {
|
||||
json_data: dict[str, Any] = {
|
||||
"query": query,
|
||||
"limit": 5,
|
||||
"lang": "en",
|
||||
|
|
@ -170,12 +194,10 @@ class FirecrawlApp:
|
|||
json_data.update(params)
|
||||
response = self._post_request(self._build_url("v2/search"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
response_data: SearchResponse = response.json()
|
||||
if not response_data.get("success"):
|
||||
raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
|
||||
return cast(dict[str, Any], response_data)
|
||||
return response_data
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "perform search")
|
||||
return {} # Avoid additional exception after handling error
|
||||
else:
|
||||
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
|
||||
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import httpx
|
|||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData
|
||||
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
|
@ -270,13 +270,13 @@ class WebsiteService:
|
|||
@classmethod
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
"status": result.get("status", "active"),
|
||||
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data: dict[str, Any] = {
|
||||
"status": result["status"],
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
"total": result["total"] or 0,
|
||||
"current": result["current"] or 0,
|
||||
"data": result["data"],
|
||||
}
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
|
|
@ -343,7 +343,7 @@ class WebsiteService:
|
|||
|
||||
@classmethod
|
||||
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
crawl_data: list[dict[str, Any]] | None = None
|
||||
crawl_data: list[FirecrawlDocumentData] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
stored_data = storage.load_once(file_key)
|
||||
|
|
@ -352,13 +352,13 @@ class WebsiteService:
|
|||
else:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get("status") != "completed":
|
||||
if result["status"] != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
crawl_data = result.get("data")
|
||||
crawl_data = result["data"]
|
||||
|
||||
if crawl_data:
|
||||
for item in crawl_data:
|
||||
if item.get("source_url") == url:
|
||||
if item["source_url"] == url:
|
||||
return dict(item)
|
||||
return None
|
||||
|
||||
|
|
@ -416,7 +416,7 @@ class WebsiteService:
|
|||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
params = {"onlyMainContent": request.only_main_content}
|
||||
return firecrawl_app.scrape_url(url=request.url, params=params)
|
||||
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
|
||||
|
||||
@classmethod
|
||||
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -104,10 +104,11 @@ class TestFirecrawlApp:
|
|||
|
||||
def test_map_known_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error"))
|
||||
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))
|
||||
|
||||
assert app.map("https://example.com") == {}
|
||||
with pytest.raises(Exception, match="map error"):
|
||||
app.map("https://example.com")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_map_unknown_error_raises(self, mocker: MockerFixture):
|
||||
|
|
@ -177,10 +178,11 @@ class TestFirecrawlApp:
|
|||
|
||||
def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error"))
|
||||
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))
|
||||
|
||||
assert app.check_crawl_status("job-1") == {}
|
||||
with pytest.raises(Exception, match="crawl error"):
|
||||
app.check_crawl_status("job-1")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
|
||||
|
|
@ -272,9 +274,10 @@ class TestFirecrawlApp:
|
|||
|
||||
def test_search_known_http_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error"))
|
||||
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
|
||||
assert app.search("python") == {}
|
||||
with pytest.raises(Exception, match="search error"):
|
||||
app.search("python")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_search_unknown_http_error(self, mocker: MockerFixture):
|
||||
|
|
|
|||
|
|
@ -443,7 +443,7 @@ def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monk
|
|||
|
||||
def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed"}
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 1, "current": 1, "data": []}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
redis_mock = MagicMock()
|
||||
|
|
|
|||
Loading…
Reference in New Issue