refactor(api): type WaterCrawl API responses with TypedDict (#33700)

BitToby 2026-03-19 03:35:44 +02:00 committed by GitHub
parent 9367020bfd
commit 9ff0d9df88
3 changed files with 76 additions and 24 deletions


@@ -1,10 +1,11 @@
 import json
 from collections.abc import Generator
-from typing import Union
+from typing import Any, Union
 from urllib.parse import urljoin
 
 import httpx
 from httpx import Response
+from typing_extensions import TypedDict
 
 from core.rag.extractor.watercrawl.exceptions import (
     WaterCrawlAuthenticationError,
@@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import (
 )
 
 
+class SpiderOptions(TypedDict):
+    max_depth: int
+    page_limit: int
+    allowed_domains: list[str]
+    exclude_paths: list[str]
+    include_paths: list[str]
+
+
+class PageOptions(TypedDict):
+    exclude_tags: list[str]
+    include_tags: list[str]
+    wait_time: int
+    include_html: bool
+    only_main_content: bool
+    include_links: bool
+    timeout: int
+    accept_cookies_selector: str
+    locale: str
+    actions: list[Any]
+
+
 class BaseAPIClient:
     def __init__(self, api_key, base_url):
         self.api_key = api_key
@@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient):
     def create_crawl_request(
         self,
         url: Union[list, str] | None = None,
-        spider_options: dict | None = None,
-        page_options: dict | None = None,
-        plugin_options: dict | None = None,
+        spider_options: SpiderOptions | None = None,
+        page_options: PageOptions | None = None,
+        plugin_options: dict[str, Any] | None = None,
     ):
         data = {
             # 'urls': url if isinstance(url, list) else [url],
@@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient):
     def scrape_url(
         self,
         url: str,
-        page_options: dict | None = None,
-        plugin_options: dict | None = None,
+        page_options: PageOptions | None = None,
+        plugin_options: dict[str, Any] | None = None,
         sync: bool = True,
         prefetched: bool = True,
     ):

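As a usage sketch of the newly typed client (not part of the diff; the API key, base URL, and option values below are placeholders): both TypedDicts are total, so a checker such as mypy requires every key and flags a missing or misspelled one in these literals.

```python
from core.rag.extractor.watercrawl.client import (
    PageOptions,
    SpiderOptions,
    WaterCrawlAPIClient,
)

# All keys are required because the TypedDicts are declared without total=False.
spider_options: SpiderOptions = {
    "max_depth": 2,
    "page_limit": 10,
    "allowed_domains": [],
    "exclude_paths": [],
    "include_paths": [],
}
page_options: PageOptions = {
    "exclude_tags": [],
    "include_tags": [],
    "wait_time": 1000,
    "include_html": False,
    "only_main_content": True,
    "include_links": False,
    "timeout": 15000,
    "accept_cookies_selector": "#cookies",
    "locale": "en-US",
    "actions": [],
}

client = WaterCrawlAPIClient(api_key="wc-...", base_url="https://example.invalid")
result = client.create_crawl_request(
    url="https://example.com",
    spider_options=spider_options,
    page_options=page_options,
)
```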

@@ -2,16 +2,39 @@ from collections.abc import Generator
 from datetime import datetime
 from typing import Any
 
-from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
+from typing_extensions import TypedDict
+
+from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
 
 
+class WatercrawlDocumentData(TypedDict):
+    title: str | None
+    description: str | None
+    source_url: str | None
+    markdown: str | None
+
+
+class CrawlJobResponse(TypedDict):
+    status: str
+    job_id: str | None
+
+
+class WatercrawlCrawlStatusResponse(TypedDict):
+    status: str
+    job_id: str | None
+    total: int
+    current: int
+    data: list[WatercrawlDocumentData]
+    time_consuming: float
+
+
 class WaterCrawlProvider:
     def __init__(self, api_key, base_url: str | None = None):
         self.client = WaterCrawlAPIClient(api_key, base_url)
 
-    def crawl_url(self, url, options: dict | Any | None = None):
+    def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse:
         options = options or {}
-        spider_options = {
+        spider_options: SpiderOptions = {
             "max_depth": 1,
             "page_limit": 1,
             "allowed_domains": [],
@@ -25,7 +48,7 @@ class WaterCrawlProvider:
             spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
 
         wait_time = options.get("wait_time", 1000)
-        page_options = {
+        page_options: PageOptions = {
             "exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
             "include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
             "wait_time": max(1000, wait_time),  # minimum wait time is 1 second
@@ -41,9 +64,9 @@ class WaterCrawlProvider:
         return {"status": "active", "job_id": result.get("uuid")}
 
-    def get_crawl_status(self, crawl_request_id):
+    def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse:
         response = self.client.get_crawl_request(crawl_request_id)
-        data = []
+        data: list[WatercrawlDocumentData] = []
         if response["status"] in ["new", "running"]:
             status = "active"
         else:
@@ -67,7 +90,7 @@ class WaterCrawlProvider:
             "time_consuming": time_consuming,
         }
 
-    def get_crawl_url_data(self, job_id, url) -> dict | None:
+    def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None:
         if not job_id:
             return self.scrape_url(url)
 
@@ -82,11 +105,11 @@ class WaterCrawlProvider:
 
         return None
 
-    def scrape_url(self, url: str):
+    def scrape_url(self, url: str) -> WatercrawlDocumentData:
         response = self.client.scrape_url(url=url, sync=True, prefetched=True)
         return self._structure_data(response)
 
-    def _structure_data(self, result_object: dict):
+    def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:
         if isinstance(result_object.get("result", {}), str):
             raise ValueError("Invalid result object. Expected a dictionary.")
 
@@ -98,7 +121,9 @@ class WaterCrawlProvider:
             "markdown": result_object.get("result", {}).get("markdown"),
         }
 
-    def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
+    def _get_results(
+        self, crawl_request_id: str, query_params: dict | None = None
+    ) -> Generator[WatercrawlDocumentData, None, None]:
        page = 0
        page_size = 100

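Similarly, a sketch of consuming the typed provider (the job id and API key are placeholders, and the provider's module path is assumed from the package layout): the new return annotations make key access on the responses checkable, where a plain `dict[str, Any]` previously hid typos.

```python
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider

provider = WaterCrawlProvider(api_key="wc-...")
status = provider.get_crawl_status("job-uuid")  # WatercrawlCrawlStatusResponse

print(status["status"], f'{status["current"]}/{status["total"]}')
for doc in status["data"]:  # each item is a WatercrawlDocumentData
    print(doc["title"], doc["source_url"])

# A misspelled key such as status["totals"] is now a mypy error rather than a
# runtime KeyError.
```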

@@ -216,8 +216,10 @@ class WebsiteService:
             "max_depth": request.options.max_depth,
             "use_sitemap": request.options.use_sitemap,
         }
-        return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
-            url=request.url, options=options
+        return dict(
+            WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
+                url=request.url, options=options
+            )
         )
 
     @classmethod
@@ -289,8 +291,8 @@ class WebsiteService:
         return crawl_status_data
 
     @classmethod
-    def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
-        return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
+    def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
+        return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))
 
     @classmethod
     def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
@@ -363,8 +365,11 @@ class WebsiteService:
         return None
 
     @classmethod
-    def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
-        return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
+    def _get_watercrawl_url_data(
+        cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
+    ) -> dict[str, Any] | None:
+        result = WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
+        return dict(result) if result is not None else None
 
     @classmethod
     def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
@@ -419,5 +424,5 @@ class WebsiteService:
         return dict(firecrawl_app.scrape_url(url=request.url, params=params))
 
     @classmethod
-    def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
-        return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
+    def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
+        return dict(WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url))
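The `dict(...)` wrappers are what keep `WebsiteService`'s `dict[str, Any]` signatures honest: mypy does not accept a TypedDict where a plain `dict[str, Any]` is declared, so each typed response is shallow-copied into an ordinary dict at the service boundary. A minimal illustration of the rule (the `to_payload` helper is hypothetical, and the provider's module path is assumed from the package layout):

```python
from typing import Any

from core.rag.extractor.watercrawl.provider import WatercrawlDocumentData


def to_payload(doc: WatercrawlDocumentData) -> dict[str, Any]:
    # return doc      # rejected by mypy: a TypedDict is not a dict[str, Any]
    return dict(doc)  # shallow copy as a plain dict satisfies the annotation
```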