dify/api/dify_graph/nodes/document_extractor/node.py

773 lines
29 KiB
Python
Raw Normal View History

import csv
import io
import json
import logging
import os
import tempfile
import zipfile
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
import charset_normalizer
import docx
import pandas as pd
import pypandoc
import pypdfium2
import webvtt
import yaml
2025-03-19 02:24:35 +00:00
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod, file_manager
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.variables import ArrayFileSegment
from dify_graph.variables.segments import ArrayStringSegment, FileSegment
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
# Module logger; the extractors below use it for non-fatal warnings (e.g. a skipped DOCX table).
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed only for the string annotations in DocumentExtractorNode.__init__; importing them only
    # under type checking presumably avoids a runtime import cycle — TODO confirm.
    from dify_graph.entities import GraphInitParams
    from dify_graph.runtime import GraphRuntimeState
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
    """
    Workflow node that turns file variables into text.

    The configured ``variable_selector`` must resolve to a single file or to an array of files.
    Every file is downloaded and handed to a format-specific extractor. The ``text`` output is a
    plain string for a single file and an ``ArrayStringSegment`` (one entry per file) for an array.
    Extraction failures are reported as a FAILED run result rather than raised.
    """

    node_type = NodeType.DOCUMENT_EXTRACTOR

    @classmethod
    def version(cls) -> str:
        return "1"

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        *,
        unstructured_api_config: UnstructuredApiConfig | None = None,
        http_client: HttpClientProtocol,
    ) -> None:
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # A default config keeps extractors able to read api_url / api_key even when none was given.
        self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
        self._http_client = http_client

    def _run(self):
        selector = self.node_data.variable_selector
        segment = self.graph_runtime_state.variable_pool.get(selector)

        if segment is None:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"File variable not found for selector: {selector}",
            )
        # An empty value passes this check whatever the segment type; the type dispatch below handles it.
        if segment.value and not isinstance(segment, ArrayFileSegment | FileSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Variable {selector} is not an ArrayFileSegment",
            )

        payload = segment.value
        inputs = {"variable_selector": selector}
        process_data = {"documents": payload if isinstance(payload, list) else [payload]}

        def extract(file: File) -> str:
            return _extract_text_from_file(
                self._http_client, file, unstructured_api_config=self._unstructured_api_config
            )

        try:
            if isinstance(payload, list):
                text_output: Any = ArrayStringSegment(value=[extract(item) for item in payload])
            elif isinstance(payload, File):
                text_output = extract(payload)
            else:
                raise DocumentExtractorError(f"Unsupported variable type: {type(payload)}")
        except DocumentExtractorError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=process_data,
            outputs={"text": text_output},
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Map ``<node_id>.files`` to the selector of this node's input file variable."""
        selector = DocumentExtractorNodeData.model_validate(node_data).variable_selector
        return {f"{node_id}.files": selector}
def _extract_text_by_mime_type(
*,
file_content: bytes,
mime_type: str,
unstructured_api_config: UnstructuredApiConfig,
) -> str:
"""Extract text from a file based on its MIME type."""
match mime_type:
case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml":
return _extract_text_from_plain_text(file_content)
case "application/pdf":
return _extract_text_from_pdf(file_content)
case "application/msword":
return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
return _extract_text_from_docx(file_content)
case "text/csv":
return _extract_text_from_csv(file_content)
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
return _extract_text_from_excel(file_content)
case "application/vnd.ms-powerpoint":
return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
case "application/epub+zip":
return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
case "message/rfc822":
return _extract_text_from_eml(file_content)
case "application/vnd.ms-outlook":
return _extract_text_from_msg(file_content)
case "application/json":
return _extract_text_from_json(file_content)
case "application/x-yaml" | "text/yaml":
return _extract_text_from_yaml(file_content)
case "text/vtt":
return _extract_text_from_vtt(file_content)
case "text/properties":
return _extract_text_from_properties(file_content)
case _:
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
# Extensions whose content is only decoded as text, not parsed further.
_PLAIN_TEXT_EXTENSIONS = frozenset(
    {
        ".txt",
        ".markdown",
        ".md",
        ".mdx",
        ".html",
        ".htm",
        ".xml",
        ".c",
        ".h",
        ".cpp",
        ".hpp",
        ".cc",
        ".cxx",
        ".c++",
        ".py",
        ".js",
        ".ts",
        ".jsx",
        ".tsx",
        ".java",
        ".php",
        ".rb",
        ".go",
        ".rs",
        ".swift",
        ".kt",
        ".scala",
        ".sh",
        ".bash",
        ".bat",
        ".ps1",
        ".sql",
        ".r",
        ".m",
        ".pl",
        ".lua",
        ".vim",
        ".asm",
        ".s",
        ".css",
        ".scss",
        ".less",
        ".sass",
        ".ini",
        ".cfg",
        ".conf",
        ".toml",
        ".env",
        ".log",
        # NOTE(review): .vtt is only decoded here, while the MIME path routes text/vtt to
        # _extract_text_from_vtt (speaker-merged output) — confirm which behavior is intended.
        ".vtt",
    }
)


def _extract_text_by_file_extension(
    *,
    file_content: bytes,
    file_extension: str,
    unstructured_api_config: UnstructuredApiConfig,
) -> str:
    """
    Dispatch ``file_content`` to an extractor chosen by ``file_extension``.

    The extension is compared exactly as given (leading dot included, case-sensitive).

    Raises:
        UnsupportedFileTypeError: If no extractor is registered for the extension.
        Errors raised by the selected extractor propagate unchanged.
    """
    if file_extension in _PLAIN_TEXT_EXTENSIONS:
        return _extract_text_from_plain_text(file_content)

    # Extractors that only need the raw bytes.
    bytes_only_extractors = {
        ".json": _extract_text_from_json,
        ".yaml": _extract_text_from_yaml,
        ".yml": _extract_text_from_yaml,
        ".pdf": _extract_text_from_pdf,
        ".docx": _extract_text_from_docx,
        ".csv": _extract_text_from_csv,
        ".xls": _extract_text_from_excel,
        ".xlsx": _extract_text_from_excel,
        ".eml": _extract_text_from_eml,
        ".msg": _extract_text_from_msg,
        ".properties": _extract_text_from_properties,
    }
    simple_extractor = bytes_only_extractors.get(file_extension)
    if simple_extractor is not None:
        return simple_extractor(file_content)

    # Extractors that may call the Unstructured API and therefore take its configuration.
    api_aware_extractors = {
        ".doc": _extract_text_from_doc,
        ".ppt": _extract_text_from_ppt,
        ".pptx": _extract_text_from_pptx,
        ".epub": _extract_text_from_epub,
    }
    api_extractor = api_aware_extractors.get(file_extension)
    if api_extractor is not None:
        return api_extractor(file_content, unstructured_api_config=unstructured_api_config)

    raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
def _extract_text_from_plain_text(file_content: bytes) -> str:
    """
    Decode the bytes of a text file into a string.

    The encoding is guessed with ``charset_normalizer``, restricted to UTF-8, Latin-1 and cp1252.
    UTF-8 is used when nothing is detected or the detected codec name is unknown. Undecodable bytes
    are dropped (``errors="ignore"``) instead of raising.
    """
    candidate_codecs = ["utf_8", "latin_1", "cp1252"]
    try:
        best_match = charset_normalizer.from_bytes(file_content, cp_isolation=candidate_codecs).best()
        encoding = (best_match.encoding if best_match else None) or "utf-8"
        return file_content.decode(encoding, errors="ignore")
    except (UnicodeDecodeError, LookupError):
        # With errors="ignore", decoding as UTF-8 cannot fail, so this fallback always succeeds.
        return file_content.decode("utf-8", errors="ignore")
def _extract_text_from_json(file_content: bytes) -> str:
    """
    Parse a JSON file and return it pretty-printed (2-space indent, non-ASCII kept as-is).

    The encoding is guessed with ``charset_normalizer`` (UTF-8 when nothing is detected). If decoding
    or parsing with that encoding fails, UTF-8 is tried once more before giving up.

    Raises:
        TextExtractionError: If the content cannot be decoded and parsed as JSON.
    """

    def _pretty_print(text: str) -> str:
        # Decoding with the plain "utf_8" codec keeps a leading BOM, which json.loads rejects
        # ("Unexpected UTF-8 BOM"); strip it so BOM-prefixed files parse.
        data = json.loads(text.removeprefix("\ufeff"))
        return json.dumps(data, indent=2, ensure_ascii=False)

    try:
        best_match = charset_normalizer.from_bytes(file_content).best()
        encoding = (best_match.encoding if best_match else None) or "utf-8"
        return _pretty_print(file_content.decode(encoding, errors="ignore"))
    except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e:
        # Detected encoding unusable or produced invalid JSON: retry with UTF-8 as a last resort.
        try:
            return _pretty_print(file_content.decode("utf-8", errors="ignore"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
def _extract_text_from_yaml(file_content: bytes) -> str:
    """
    Re-serialize every document of a YAML file into normalized YAML text.

    Documents are loaded with ``yaml.safe_load_all`` and dumped with Unicode preserved and key order
    kept. The encoding is guessed with ``charset_normalizer`` (UTF-8 when nothing is detected). On a
    decode or YAML error, UTF-8 is tried once more.

    Raises:
        TextExtractionError: If the content cannot be decoded and parsed as YAML.
    """

    def _round_trip(source: str) -> str:
        documents = yaml.safe_load_all(source)
        return yaml.dump_all(documents, allow_unicode=True, sort_keys=False)

    try:
        best_match = charset_normalizer.from_bytes(file_content).best()
        encoding = (best_match.encoding if best_match else None) or "utf-8"
        return _round_trip(file_content.decode(encoding, errors="ignore"))
    except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
        try:
            return _round_trip(file_content.decode("utf-8", errors="ignore"))
        except (UnicodeDecodeError, yaml.YAMLError):
            raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
def _extract_text_from_pdf(file_content: bytes) -> str:
    """
    Extract the text of every page of a PDF with pypdfium2, concatenated without separators.

    The document, each page and each text page are closed even when extraction fails partway.

    Raises:
        TextExtractionError: If the PDF cannot be opened or read.
    """
    try:
        # autoclose=True: closing the document also closes the in-memory buffer.
        pdf_document = pypdfium2.PdfDocument(io.BytesIO(file_content), autoclose=True)
        try:
            page_texts: list[str] = []
            for page in pdf_document:
                try:
                    text_page = page.get_textpage()
                    try:
                        page_texts.append(text_page.get_text_range())
                    finally:
                        text_page.close()
                finally:
                    page.close()
            return "".join(page_texts)
        finally:
            pdf_document.close()
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
    """
    Extract text from a legacy Word (.doc) file through the Unstructured API.

    The bytes are written to a temporary ``.doc`` file that is sent to ``partition_via_api``. The
    temporary file is removed whether or not the call succeeds.

    Returns:
        The ``text`` of each returned element, one per line.

    Raises:
        TextExtractionError: If no API URL is configured or partitioning fails.
    """
    from unstructured.partition.api import partition_via_api

    if not unstructured_api_config.api_url:
        raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.")
    api_key = unstructured_api_config.api_key or ""

    temp_path: str | None = None
    try:
        # delete=False so the file can be reopened by name after it is closed.
        with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
            temp_path = temp_file.name
            temp_file.write(file_content)
        with open(temp_path, "rb") as file:
            elements = partition_via_api(
                file=file,
                metadata_filename=temp_path,
                api_url=unstructured_api_config.api_url,
                api_key=api_key,
            )
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e
    finally:
        # Previously the file was only removed on success, leaking it whenever the API call failed.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                logger.warning("Failed to remove temporary file %s", temp_path, exc_info=True)
def parser_docx_part(block, doc: Document, content_items, i):
    """
    Wrap a body-level DOCX XML element and record it in ``content_items``.

    Paragraph elements (``CT_P``) are appended as ``(i, "paragraph", Paragraph)`` and table elements
    (``CT_Tbl``) as ``(i, "table", Table)``; any other element is ignored. Mutates ``content_items``
    in place and returns None.
    """
    if isinstance(block, CT_P):
        kind, wrapper = "paragraph", Paragraph
    elif isinstance(block, CT_Tbl):
        kind, wrapper = "table", Table
    else:
        return
    content_items.append((i, kind, wrapper(block, doc)))
def _normalize_docx_zip(file_content: bytes) -> bytes:
    """
    Rewrite a DOCX (ZIP) archive so that every entry name uses forward slashes.

    Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
    ZIP entry names use backslash (\\) as path separator instead of the forward
    slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
    "word\\document.xml" is never found when python-docx looks for
    "word/document.xml", which triggers a KeyError about a missing relationship.
    Only entry names are changed; entry contents are copied as-is.

    Returns:
        The rewritten archive. ``file_content`` itself is returned unchanged when no entry name
        contains a backslash, or when the input is not a valid ZIP archive (so python-docx can
        report the real error).
    """
    try:
        with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
            entries = zin.infolist()
            if not any("\\" in entry.filename for entry in entries):
                # Nothing to fix: skip recompressing the whole archive.
                return file_content
            out_buf = io.BytesIO()
            with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
                for entry in entries:
                    # Read via the ZipInfo itself (not its name) so duplicate names map to the right data.
                    data = zin.read(entry)
                    entry.filename = entry.filename.replace("\\", "/")
                    zout.writestr(entry, data)
        return out_buf.getvalue()
    except zipfile.BadZipFile:
        return file_content
def _docx_table_to_markdown(table: Table) -> str:
    """Render a python-docx table as Markdown, using its first row as the header row."""

    def _row_line(cells) -> str:
        # Markdown table cells cannot contain raw newlines.
        return "| " + " | ".join(cell.text.replace("\n", "<br>") for cell in cells) + " |\n"

    header_cells = table.rows[0].cells
    parts = [_row_line(header_cells), "| " + " | ".join(["---"] * len(header_cells)) + " |\n"]
    parts.extend(_row_line(row.cells) for row in table.rows[1:])
    return "".join(parts)


def _extract_text_from_docx(file_content: bytes) -> str:
    """
    Extract text from a DOCX file.

    Paragraphs and tables that are direct children of the document body are emitted in document
    order: paragraphs as their text, and tables with at least one non-blank cell as Markdown tables.
    Other body elements are skipped. A table that fails to render is logged and skipped.

    If python-docx cannot open the file, ZIP entry names are normalized with
    ``_normalize_docx_zip`` and parsing is retried once.

    Raises:
        TextExtractionError: If the document cannot be parsed, even after normalization.
    """
    try:
        try:
            doc = docx.Document(io.BytesIO(file_content))
        except Exception as e:
            logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
            # Some DOCX files (e.g. exported by Evernote on Windows) use backslash path separators
            # in their ZIP entry names, which python-docx cannot resolve; normalize and retry once.
            doc = docx.Document(io.BytesIO(_normalize_docx_zip(file_content)))

        # Wrap body-level paragraph/table elements, preserving their document order.
        content_items: list[tuple[int, str, Table | Paragraph]] = []
        for index, element in enumerate(doc.element.body):
            parser_docx_part(element, doc, content_items, index)

        text: list[str] = []
        for _, item_type, item in content_items:
            if item_type == "paragraph":
                if not isinstance(item, Table):
                    text.append(item.text)
            elif item_type == "table" and isinstance(item, Table):
                try:
                    if any(cell.text.strip() for row in item.rows for cell in row.cells):
                        text.append(_docx_table_to_markdown(item))
                except Exception as e:
                    logger.warning("Failed to extract table from DOCX: %s", e)
        return "\n".join(text)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes:
    """
    Fetch the raw bytes of ``file``.

    Remote-URL files are fetched with ``http_client`` (non-2xx responses raise through
    ``raise_for_status``); every other transfer method goes through ``file_manager.download``.

    Raises:
        FileDownloadError: Wrapping any failure, including a remote file without a URL.
    """
    try:
        if file.transfer_method != FileTransferMethod.REMOTE_URL:
            return file_manager.download(file)

        url = file.remote_url
        if url is None:
            raise FileDownloadError("Missing URL for remote file")
        response = http_client.get(url)
        response.raise_for_status()
        return response.content
    except Exception as exc:
        raise FileDownloadError(f"Error downloading file: {exc}") from exc
def _extract_text_from_file(
    http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig
) -> str:
    """
    Download ``file`` and extract its text.

    The file extension, compared case-insensitively (so ``.PDF`` works like ``.pdf``), selects the
    extractor. When the extension is missing, or is not recognised but a MIME type is available,
    the MIME type is used instead.

    Raises:
        FileDownloadError: If the content cannot be downloaded.
        UnsupportedFileTypeError: If neither the extension nor the MIME type identifies a supported
            format (the extension error is reported when both were tried).
        Errors raised by the selected extractor propagate unchanged.
    """
    file_content = _download_file_content(http_client, file)
    extension = (file.extension or "").lower()
    mime_type = file.mime_type

    if extension:
        try:
            return _extract_text_by_file_extension(
                file_content=file_content,
                file_extension=extension,
                unstructured_api_config=unstructured_api_config,
            )
        except UnsupportedFileTypeError as extension_error:
            if not mime_type:
                raise
            try:
                return _extract_text_by_mime_type(
                    file_content=file_content,
                    mime_type=mime_type,
                    unstructured_api_config=unstructured_api_config,
                )
            except UnsupportedFileTypeError:
                # Neither lookup matched; report the extension, which takes precedence.
                raise extension_error from None

    if mime_type:
        return _extract_text_by_mime_type(
            file_content=file_content,
            mime_type=mime_type,
            unstructured_api_config=unstructured_api_config,
        )

    raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
def _extract_text_from_csv(file_content: bytes) -> str:
    """
    Convert CSV content into a Markdown table.

    The first row becomes the header. Newlines inside cells are replaced by spaces and carriage
    returns are removed. The encoding is guessed with ``charset_normalizer`` (UTF-8 when nothing is
    detected or the codec is unknown). Empty input yields an empty string.

    Raises:
        TextExtractionError: If the content cannot be read as CSV.
    """
    try:
        best_match = charset_normalizer.from_bytes(file_content).best()
        encoding = (best_match.encoding if best_match else None) or "utf-8"
        try:
            decoded = file_content.decode(encoding, errors="ignore")
        except (UnicodeDecodeError, LookupError):
            decoded = file_content.decode("utf-8", errors="ignore")
        # A leading UTF-8 BOM (common in spreadsheet exports) survives the plain "utf_8" codec and
        # would otherwise end up in the first header cell.
        decoded = decoded.removeprefix("\ufeff")

        rows = list(csv.reader(io.StringIO(decoded)))
        if not rows:
            return ""

        def _flatten(cell: str) -> str:
            return cell.replace("\n", " ").replace("\r", "")

        header = rows[0]
        lines = [
            "| " + " | ".join(_flatten(cell) for cell in header) + " |",
            # A Markdown delimiter cell needs at least one "-", so empty header names still get one.
            "| " + " | ".join("-" * max(len(cell), 1) for cell in header) + " |",
        ]
        lines.extend("| " + " | ".join(_flatten(cell) for cell in row) + " |" for row in rows[1:])
        return "".join(f"{line}\n" for line in lines)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
def _extract_text_from_excel(file_content: bytes) -> str:
    """
    Extract text from an Excel workbook using pandas, one Markdown table per sheet.

    Fully empty rows are dropped, and multi-line cell values and column names are joined onto one
    line. Tables are separated by a blank line. A sheet that fails to convert is logged and skipped.

    Raises:
        TextExtractionError: If the workbook itself cannot be opened.
    """

    def _construct_markdown_table(df: pd.DataFrame) -> str:
        """Manually construct a Markdown table from a DataFrame with string column names."""
        header_row = "| " + " | ".join(df.columns) + " |"
        # A Markdown delimiter cell needs at least one "-", so empty column names still get one.
        separator_row = "| " + " | ".join("-" * max(len(col), 1) for col in df.columns) + " |"
        data_rows = ["| " + " | ".join(map(str, row)) + " |" for _, row in df.iterrows()]
        return "\n".join([header_row, separator_row, *data_rows])

    try:
        # Context manager so the underlying workbook handle is released.
        with pd.ExcelFile(io.BytesIO(file_content)) as excel_file:
            markdown_table = ""
            for sheet_name in excel_file.sheet_names:
                try:
                    df = excel_file.parse(sheet_name=sheet_name)
                    df.dropna(how="all", inplace=True)
                    # Combine multi-line text in each cell into a single line
                    df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
                    # Combine multi-line text in column names into a single line
                    df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
                    markdown_table += _construct_markdown_table(df) + "\n\n"
                except Exception as e:
                    # Best effort: keep the other sheets, but do not drop this one silently.
                    logger.warning("Skipping Excel sheet %r: %s", sheet_name, e)
            return markdown_table
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
    """
    Extract text from a legacy PowerPoint (.ppt) file.

    Uses the Unstructured API (via a temporary ``.ppt`` file, always removed afterwards) when
    ``unstructured_api_config.api_url`` is set; otherwise partitions locally with ``partition_ppt``.

    Returns:
        The ``text`` of each element, one per line.

    Raises:
        TextExtractionError: If partitioning fails.
    """
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.ppt import partition_ppt

    api_key = unstructured_api_config.api_key or ""
    temp_path: str | None = None
    try:
        if unstructured_api_config.api_url:
            # delete=False so the file can be reopened by name after it is closed.
            with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(file_content)
            with open(temp_path, "rb") as file:
                elements = partition_via_api(
                    file=file,
                    metadata_filename=temp_path,
                    api_url=unstructured_api_config.api_url,
                    api_key=api_key,
                )
        else:
            with io.BytesIO(file_content) as file:
                elements = partition_ppt(file=file)
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
    finally:
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                logger.warning("Failed to remove temporary file %s", temp_path, exc_info=True)
def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
    """
    Extract text from a PowerPoint (.pptx) file.

    Uses the Unstructured API (via a temporary ``.pptx`` file, always removed afterwards) when
    ``unstructured_api_config.api_url`` is set; otherwise partitions locally with ``partition_pptx``.

    Returns:
        The ``text`` of each element, one per line.

    Raises:
        TextExtractionError: If partitioning fails.
    """
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.pptx import partition_pptx

    api_key = unstructured_api_config.api_key or ""
    temp_path: str | None = None
    try:
        if unstructured_api_config.api_url:
            # delete=False so the file can be reopened by name after it is closed.
            with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(file_content)
            with open(temp_path, "rb") as file:
                elements = partition_via_api(
                    file=file,
                    metadata_filename=temp_path,
                    api_url=unstructured_api_config.api_url,
                    api_key=api_key,
                )
        else:
            with io.BytesIO(file_content) as file:
                elements = partition_pptx(file=file)
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
    finally:
        # Previously the file was only removed on success, leaking it whenever the API call failed.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                logger.warning("Failed to remove temporary file %s", temp_path, exc_info=True)
def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
    """
    Extract text from an EPUB file.

    Uses the Unstructured API (via a temporary ``.epub`` file, always removed afterwards) when
    ``unstructured_api_config.api_url`` is set. Otherwise it partitions locally with
    ``partition_epub``, first making sure a pandoc binary is available.

    Returns:
        The string form of each element, one per line.

    Raises:
        TextExtractionError: If partitioning fails.
    """
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.epub import partition_epub

    api_key = unstructured_api_config.api_key or ""
    temp_path: str | None = None
    try:
        if unstructured_api_config.api_url:
            # delete=False so the file can be reopened by name after it is closed.
            with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(file_content)
            with open(temp_path, "rb") as file:
                elements = partition_via_api(
                    file=file,
                    metadata_filename=temp_path,
                    api_url=unstructured_api_config.api_url,
                    api_key=api_key,
                )
        else:
            # Only download pandoc when none is installed, instead of on every extraction.
            # pypandoc raises OSError when it cannot locate a pandoc binary.
            try:
                pypandoc.get_pandoc_version()
            except OSError:
                pypandoc.download_pandoc()
            with io.BytesIO(file_content) as file:
                elements = partition_epub(file=file)
        return "\n".join([str(element) for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e
    finally:
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                logger.warning("Failed to remove temporary file %s", temp_path, exc_info=True)
def _extract_text_from_eml(file_content: bytes) -> str:
    """
    Extract text from an email (.eml) file with unstructured's ``partition_email``.

    Returns:
        The string form of each partitioned element, one per line.

    Raises:
        TextExtractionError: Wrapping any partitioning failure.
    """
    from unstructured.partition.email import partition_email

    try:
        with io.BytesIO(file_content) as buffer:
            return "\n".join(map(str, partition_email(file=buffer)))
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EML: {e}") from e
def _extract_text_from_msg(file_content: bytes) -> str:
    """
    Extract text from an Outlook (.msg) message with unstructured's ``partition_msg``.

    Returns:
        The string form of each partitioned element, one per line.

    Raises:
        TextExtractionError: Wrapping any partitioning failure.
    """
    from unstructured.partition.msg import partition_msg

    try:
        with io.BytesIO(file_content) as buffer:
            return "\n".join(map(str, partition_msg(file=buffer)))
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from MSG: {e}") from e
def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
    """
    Turn a WebVTT subtitle file into a speaker-annotated transcript.

    Each caption contributes its voice (speaker) and text. Consecutive captions by the same named
    speaker are merged into one line, joined by spaces. Captions without a voice are never merged
    and keep their own text and position. Each output line reads ``<speaker> "<text>"``, with the
    speaker left empty when unknown.

    Raises:
        TextExtractionError: If the content cannot be parsed as WebVTT.
    """
    try:
        # Remove a leading byte-order mark before parsing.
        text = _extract_text_from_plain_text(vtt_bytes).lstrip("\ufeff")
        captions = [(caption.voice, caption.text) for caption in webvtt.from_string(text)]
    except Exception as e:
        # Wrap parse errors like the other extractors so the node reports a failure instead of crashing.
        raise TextExtractionError(f"Failed to extract text from VTT: {str(e)}") from e

    merged: list[tuple[str | None, str]] = []
    for speaker, caption_text in captions:
        if speaker is not None and merged and merged[-1][0] == speaker:
            # Same named speaker as the previous line: merge the utterances.
            merged[-1] = (speaker, f"{merged[-1][1]} {caption_text}")
        else:
            # New speaker, or no speaker at all: start a new line with this caption's own text.
            # (Previously voiceless captions re-emitted the preceding speaker's text instead.)
            merged.append((speaker, caption_text))

    return "\n".join(f'{speaker or ""} "{caption_text}"' for speaker, caption_text in merged)
def _extract_text_from_properties(file_content: bytes) -> str:
    """
    Render a Java-style ``.properties`` file as ``key: value`` lines.

    Blank lines and comment lines (starting with ``#`` or ``!``) are kept, stripped of surrounding
    whitespace. For other lines the key ends at the first ``=`` or ``:``, whichever comes first.
    A line with neither becomes a key with an empty value.

    Raises:
        TextExtractionError: If the file cannot be decoded or processed.
    """
    try:
        text = _extract_text_from_plain_text(file_content)
        result: list[str] = []
        for raw_line in text.splitlines():
            line = raw_line.strip()
            # Preserve comments and empty lines
            if not line or line.startswith(("#", "!")):
                result.append(line)
                continue
            # Split at the earliest separator so values may contain the other one
            # (e.g. "url: http://host?a=1" keeps "http://host?a=1" as the value).
            separator_positions = [pos for pos in (line.find("="), line.find(":")) if pos != -1]
            if separator_positions:
                cut = min(separator_positions)
                key, value = line[:cut], line[cut + 1 :]
            else:
                key, value = line, ""
            result.append(f"{key.strip()}: {value.strip()}")
        return "\n".join(result)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e