mirror of https://github.com/langgenius/dify.git
refactor(api): type WaterCrawl API responses with TypedDict (#33700)
This commit is contained in:
parent
9367020bfd
commit
9ff0d9df88
|
|
@ -1,10 +1,11 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.rag.extractor.watercrawl.exceptions import (
|
||||
WaterCrawlAuthenticationError,
|
||||
|
|
@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import (
|
|||
)
|
||||
|
||||
|
||||
class SpiderOptions(TypedDict):
    """Spider (crawl traversal) options sent to the WaterCrawl API.

    NOTE(review): field semantics are inferred from the field names and the
    caller in WaterCrawlProvider.crawl_url — confirm against the WaterCrawl
    API reference.
    """

    # Maximum link-follow depth for the crawl.
    max_depth: int
    # Upper bound on the number of pages fetched.
    page_limit: int
    # Domains the spider is allowed to visit.
    allowed_domains: list[str]
    # URL path patterns to skip.
    exclude_paths: list[str]
    # URL path patterns to restrict the crawl to.
    include_paths: list[str]
|
||||
|
||||
|
||||
class PageOptions(TypedDict):
    """Per-page rendering/extraction options sent to the WaterCrawl API.

    NOTE(review): field semantics are inferred from the field names and the
    caller in WaterCrawlProvider.crawl_url — confirm against the WaterCrawl
    API reference.
    """

    # HTML tags to drop from the extracted content.
    exclude_tags: list[str]
    # HTML tags to keep in the extracted content.
    include_tags: list[str]
    # Milliseconds to wait before capturing the page (caller enforces >= 1000).
    wait_time: int
    # Whether to include raw HTML in the response.
    include_html: bool
    # Whether to restrict extraction to the main content area.
    only_main_content: bool
    # Whether to include discovered links in the response.
    include_links: bool
    # Per-page timeout.
    timeout: int
    # CSS selector used to accept cookie banners.
    accept_cookies_selector: str
    # Locale to render the page with.
    locale: str
    # Browser actions to perform before capture; payload shape is API-defined.
    actions: list[Any]
|
||||
|
||||
|
||||
class BaseAPIClient:
|
||||
def __init__(self, api_key, base_url):
|
||||
self.api_key = api_key
|
||||
|
|
@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient):
|
|||
def create_crawl_request(
|
||||
self,
|
||||
url: Union[list, str] | None = None,
|
||||
spider_options: dict | None = None,
|
||||
page_options: dict | None = None,
|
||||
plugin_options: dict | None = None,
|
||||
spider_options: SpiderOptions | None = None,
|
||||
page_options: PageOptions | None = None,
|
||||
plugin_options: dict[str, Any] | None = None,
|
||||
):
|
||||
data = {
|
||||
# 'urls': url if isinstance(url, list) else [url],
|
||||
|
|
@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient):
|
|||
def scrape_url(
|
||||
self,
|
||||
url: str,
|
||||
page_options: dict | None = None,
|
||||
plugin_options: dict | None = None,
|
||||
page_options: PageOptions | None = None,
|
||||
plugin_options: dict[str, Any] | None = None,
|
||||
sync: bool = True,
|
||||
prefetched: bool = True,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -2,16 +2,39 @@ from collections.abc import Generator
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
|
||||
|
||||
|
||||
class WatercrawlDocumentData(TypedDict):
|
||||
title: str | None
|
||||
description: str | None
|
||||
source_url: str | None
|
||||
markdown: str | None
|
||||
|
||||
|
||||
class CrawlJobResponse(TypedDict):
|
||||
status: str
|
||||
job_id: str | None
|
||||
|
||||
|
||||
class WatercrawlCrawlStatusResponse(TypedDict):
    """Response shape returned by WaterCrawlProvider.get_crawl_status."""

    # Normalized job state ("active" while the crawl is new/running).
    status: str
    # Identifier of the crawl request; None when unavailable.
    job_id: str | None
    # Total number of pages expected.
    total: int
    # Number of pages processed so far.
    current: int
    # Structured documents collected so far.
    data: list[WatercrawlDocumentData]
    # Elapsed time of the crawl, in seconds — presumably; verify against
    # the duration computation in get_crawl_status.
    time_consuming: float
|
||||
|
||||
|
||||
class WaterCrawlProvider:
|
||||
def __init__(self, api_key, base_url: str | None = None):
|
||||
self.client = WaterCrawlAPIClient(api_key, base_url)
|
||||
|
||||
def crawl_url(self, url, options: dict | Any | None = None):
|
||||
def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse:
|
||||
options = options or {}
|
||||
spider_options = {
|
||||
spider_options: SpiderOptions = {
|
||||
"max_depth": 1,
|
||||
"page_limit": 1,
|
||||
"allowed_domains": [],
|
||||
|
|
@ -25,7 +48,7 @@ class WaterCrawlProvider:
|
|||
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
|
||||
|
||||
wait_time = options.get("wait_time", 1000)
|
||||
page_options = {
|
||||
page_options: PageOptions = {
|
||||
"exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
|
||||
"include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
|
||||
"wait_time": max(1000, wait_time), # minimum wait time is 1 second
|
||||
|
|
@ -41,9 +64,9 @@ class WaterCrawlProvider:
|
|||
|
||||
return {"status": "active", "job_id": result.get("uuid")}
|
||||
|
||||
def get_crawl_status(self, crawl_request_id):
|
||||
def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse:
|
||||
response = self.client.get_crawl_request(crawl_request_id)
|
||||
data = []
|
||||
data: list[WatercrawlDocumentData] = []
|
||||
if response["status"] in ["new", "running"]:
|
||||
status = "active"
|
||||
else:
|
||||
|
|
@ -67,7 +90,7 @@ class WaterCrawlProvider:
|
|||
"time_consuming": time_consuming,
|
||||
}
|
||||
|
||||
def get_crawl_url_data(self, job_id, url) -> dict | None:
|
||||
def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None:
|
||||
if not job_id:
|
||||
return self.scrape_url(url)
|
||||
|
||||
|
|
@ -82,11 +105,11 @@ class WaterCrawlProvider:
|
|||
|
||||
return None
|
||||
|
||||
def scrape_url(self, url: str) -> WatercrawlDocumentData:
    """Synchronously scrape a single URL and return its structured document.

    Delegates the fetch to the API client (sync + prefetched mode) and
    normalizes the raw response via ``_structure_data``.
    """
    raw_response = self.client.scrape_url(url=url, sync=True, prefetched=True)
    return self._structure_data(raw_response)
|
||||
|
||||
def _structure_data(self, result_object: dict):
|
||||
def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:
|
||||
if isinstance(result_object.get("result", {}), str):
|
||||
raise ValueError("Invalid result object. Expected a dictionary.")
|
||||
|
||||
|
|
@ -98,7 +121,9 @@ class WaterCrawlProvider:
|
|||
"markdown": result_object.get("result", {}).get("markdown"),
|
||||
}
|
||||
|
||||
def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
|
||||
def _get_results(
|
||||
self, crawl_request_id: str, query_params: dict | None = None
|
||||
) -> Generator[WatercrawlDocumentData, None, None]:
|
||||
page = 0
|
||||
page_size = 100
|
||||
|
||||
|
|
|
|||
|
|
@ -216,8 +216,10 @@ class WebsiteService:
|
|||
"max_depth": request.options.max_depth,
|
||||
"use_sitemap": request.options.use_sitemap,
|
||||
}
|
||||
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
return dict(
|
||||
WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -289,8 +291,8 @@ class WebsiteService:
|
|||
return crawl_status_data
|
||||
|
||||
@classmethod
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
    """Fetch the crawl status for ``job_id`` from WaterCrawl as a plain dict.

    The provider returns a TypedDict; it is copied into a regular ``dict``
    to match this service's ``dict[str, Any]`` return contract.
    """
    provider = WaterCrawlProvider(api_key, config.get("base_url"))
    return dict(provider.get_crawl_status(job_id))
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
|
|
@ -363,8 +365,11 @@ class WebsiteService:
|
|||
return None
|
||||
|
||||
@classmethod
def _get_watercrawl_url_data(
    cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
) -> dict[str, Any] | None:
    """Return the crawled document for ``url`` in crawl ``job_id``.

    Returns ``None`` when the provider has no data for the URL; otherwise
    the TypedDict result is copied into a plain ``dict``.
    """
    provider = WaterCrawlProvider(api_key, config.get("base_url"))
    document = provider.get_crawl_url_data(job_id, url)
    if document is None:
        return None
    return dict(document)
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
|
||||
|
|
@ -419,5 +424,5 @@ class WebsiteService:
|
|||
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
|
||||
|
||||
@classmethod
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
    """Scrape ``request.url`` via WaterCrawl and return the result as a dict.

    The provider's TypedDict result is copied into a plain ``dict`` to match
    this service's ``dict[str, Any]`` return contract.
    """
    provider = WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url"))
    return dict(provider.scrape_url(request.url))
|
||||
|
|
|
|||
Loading…
Reference in New Issue