dify/api/tests/integration_tests/vdb/__mock/hologres.py

210 lines
7.4 KiB
Python

import json
import os
from typing import Any
import holo_search_sdk as holo
import pytest
from _pytest.monkeypatch import MonkeyPatch
from psycopg import sql as psql
# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}}
_mock_tables: dict[str, dict[str, dict[str, Any]]] = {}
class MockSearchQuery:
    """Mock query builder returned by ``MockTable.search_vector`` / ``MockTable.search_text``.

    Supports the fluent ``select(...).where(...).limit(...).fetchall()`` chain and
    serves rows from the shared in-memory ``_mock_tables``.
    """

    def __init__(self, table_name: str, search_type: str):
        """
        Args:
            table_name: Name of the mock table to read from.
            search_type: ``"vector"`` selects the vector-search row shape; any other
                value selects the full-text row shape.
        """
        self._table_name = table_name
        self._search_type = search_type
        self._limit_val = 10
        self._filter_sql = None

    def select(self, columns):
        """Accept and ignore a column projection; return ``self`` for chaining."""
        return self

    def limit(self, n):
        """Cap the number of rows returned by ``fetchall`` (default 10); return ``self``."""
        self._limit_val = n
        return self

    def where(self, filter_sql):
        """Record the filter that ``fetchall`` applies; return ``self``."""
        self._filter_sql = filter_sql
        return self

    def _apply_filter(self, row: dict[str, Any]) -> bool:
        """Return True when ``row`` satisfies the recorded filter.

        Only the document-id filter shape is understood, i.e.
        ``meta->>'document_id' IN ('doc1', 'doc2')``: every literal in the filter is
        treated as an allowed ``document_id``. No filter, or a filter without
        literals, matches every row.
        """
        if self._filter_sql is None:
            return True
        literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"]
        if not literals:
            return True
        # ``meta`` may be stored as a JSON string or as an already-decoded dict;
        # a missing/empty value is treated as an empty object instead of crashing.
        meta = row.get("meta") or {}
        if isinstance(meta, str):
            meta = json.loads(meta)
        return meta.get("document_id") in literals

    def fetchall(self):
        """Return the matching rows as tuples.

        The filter is applied *before* the limit, as ``WHERE ... LIMIT n`` does in
        SQL, so rows rejected by the filter do not consume the limit.

        Returns:
            For vector search, ``(distance, id, text, meta)`` tuples with a fixed
            distance of 0.1 (the shape expected by ``_process_vector_results``);
            otherwise ``(id, text, meta, embedding, score)`` tuples with a fixed
            score of 0.9 (the shape expected by ``_process_full_text_results``).
        """
        data = _mock_tables.get(self._table_name, {})
        matching = [row for row in data.values() if self._apply_filter(row)]
        results = []
        for row in matching[: self._limit_val]:
            if self._search_type == "vector":
                results.append((0.1, row["id"], row["text"], row["meta"]))
            else:
                results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9))
        return results
class MockTable:
    """In-memory stand-in for the table handle returned by ``client.open_table()``."""

    def __init__(self, table_name: str):
        self._table_name = table_name

    def upsert_multi(self, index_column, values, column_names, update=True, update_columns=None):
        """Insert rows, or fully overwrite them, keyed by their ``"id"`` column.

        ``index_column``, ``update`` and ``update_columns`` are accepted only for
        signature compatibility and are ignored: a row whose id already exists is
        always replaced wholesale. Raises ``ValueError`` if ``column_names`` has no
        ``"id"`` entry.
        """
        table = _mock_tables.setdefault(self._table_name, {})
        key_pos = column_names.index("id")
        for record in values:
            table[record[key_pos]] = dict(zip(column_names, record))

    def search_vector(self, vector, column, distance_method, output_name):
        """Start a vector search; all arguments are ignored by the mock."""
        return MockSearchQuery(self._table_name, "vector")

    def search_text(self, column, expression, return_score=False, return_score_name="score", return_all_columns=False):
        """Start a full-text search; all arguments are ignored by the mock."""
        return MockSearchQuery(self._table_name, "text")

    def set_vector_index(
        self, column, distance_method, base_quantization_type, max_degree, ef_construction, use_reorder
    ):
        """No-op: the mock keeps no indexes."""
        return None

    def create_text_index(self, index_name, column, tokenizer):
        """No-op: the mock keeps no indexes."""
        return None
def _extract_sql_template(query) -> str:
    """Return the SQL text of a psycopg ``SQL``/``Composed`` query without its parameters.

    All ``SQL`` fragments are concatenated in order, descending into nested
    ``Composed`` objects (such as the result of ``SQL(", ").join(...)``), while
    ``Identifier``/``Literal`` parts are dropped. ``SQL.format()`` splits a template
    at every placeholder, so keeping only the first fragment would lose everything
    after the first placeholder (e.g. the ``WHERE id IN`` of a DELETE statement).
    Any other input yields ``""``.
    """
    if isinstance(query, psql.SQL):
        return query._obj
    if isinstance(query, psql.Composed):
        return "".join(_extract_sql_template(part) for part in query)
    return ""
def _extract_identifiers_and_literals(query) -> list[Any]:
    """Collect the parameters of a psycopg ``Composed`` query as tagged tuples.

    Top-level ``Identifier`` parts become ``("ident", <first name component or "">)``
    and top-level ``Literal`` parts become ``("literal", <wrapped value>)``. For a
    nested ``Composed`` (e.g. the ``SQL(", ").join(...)`` of an ``IN (...)`` list),
    only its direct ``Literal`` children are collected. Results follow query order;
    non-``Composed`` input yields an empty list.
    """
    found: list[Any] = []
    if not isinstance(query, psql.Composed):
        return found
    for piece in query:
        if isinstance(piece, psql.Identifier):
            names = piece._obj
            found.append(("ident", names[0] if names else ""))
        elif isinstance(piece, psql.Literal):
            found.append(("literal", piece._obj))
        elif isinstance(piece, psql.Composed):
            found.extend(("literal", child._obj) for child in piece if isinstance(child, psql.Literal))
    return found
class MockHologresClient:
    """Mock ``holo_search_sdk`` client that keeps all data in the module-level ``_mock_tables``."""

    def connect(self):
        """No-op: there is no server to connect to."""

    def check_table_exist(self, table_name):
        """Return True if ``table_name`` has been created or written to."""
        return table_name in _mock_tables

    def open_table(self, table_name):
        """Return a ``MockTable`` handle for ``table_name``."""
        return MockTable(table_name)

    @staticmethod
    def _load_meta(row: dict[str, Any]) -> dict[str, Any]:
        """Return a row's ``meta`` column as a dict.

        Accepts either a JSON string or an already-decoded dict; a missing or empty
        value becomes ``{}`` rather than raising.
        """
        meta = row.get("meta") or {}
        return json.loads(meta) if isinstance(meta, str) else meta

    @staticmethod
    def _meta_filter(literals: list[Any]) -> tuple[Any, Any]:
        """Split ``[key, value, ...]`` literals into a metadata key and value ("" when absent)."""
        key = literals[0] if len(literals) > 0 else ""
        value = literals[1] if len(literals) > 1 else ""
        return key, value

    def execute(self, query, fetch_result=False):
        """Emulate the few SQL statements the Hologres vector store issues.

        The statement kind is recognised by substring checks on the query's SQL text
        (from ``_extract_sql_template``). The table name is the first ``Identifier``
        parameter and values are the ``Literal`` parameters:

        * ``CREATE TABLE``: registers an empty table if it does not exist yet.
        * ``SELECT 1 ... WHERE id = <id>``: ``[(1,)]`` if the id exists, else ``[]``.
        * ``SELECT id ... meta->><key> = <value>``: ``[(id,), ...]`` of matching rows.
        * ``DELETE`` containing ``id IN``: removes the listed ids.
        * ``DELETE`` containing ``meta->>``: removes rows whose meta key equals the value.

        Any other statement is a no-op that returns ``[]`` when ``fetch_result`` is
        set and ``None`` otherwise.
        """
        template = _extract_sql_template(query)
        params = _extract_identifiers_and_literals(query)
        literals = [v for t, v in params if t == "literal"]

        if "CREATE TABLE" in template.upper():
            # Extract table name from first identifier
            table_name = next((v for t, v in params if t == "ident"), "unknown")
            _mock_tables.setdefault(table_name, {})
            return None

        table_name = next((v for t, v in params if t == "ident"), "")

        if "SELECT 1" in template:
            # text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1
            doc_id = literals[0] if literals else ""
            return [(1,)] if doc_id in _mock_tables.get(table_name, {}) else []

        if "SELECT id" in template:
            # get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value}
            key, value = self._meta_filter(literals)
            data = _mock_tables.get(table_name, {})
            return [(doc_id,) for doc_id, row in data.items() if self._load_meta(row).get(key) == value]

        if "DELETE" in template.upper():
            data = _mock_tables.get(table_name, {})
            if "id IN" in template:
                # delete_by_ids
                for doc_id in literals:
                    data.pop(doc_id, None)
            elif "meta->>" in template:
                # delete_by_metadata_field
                key, value = self._meta_filter(literals)
                # Materialise the ids first: the dict must not change size while iterated.
                to_remove = [doc_id for doc_id, row in data.items() if self._load_meta(row).get(key) == value]
                for doc_id in to_remove:
                    data.pop(doc_id, None)
            return None

        return [] if fetch_result else None

    def drop_table(self, table_name):
        """Forget ``table_name`` and all of its rows (no error if it is unknown)."""
        _mock_tables.pop(table_name, None)
def mock_connect(**kwargs):
    """Drop-in for ``holo_search_sdk.connect``: ignore all connection kwargs and return a new mock client."""
    client = MockHologresClient()
    return client
# Mocking is enabled only when the MOCK_SWITCH environment variable equals "true" (case-insensitive).
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_hologres_mock(monkeypatch: MonkeyPatch):
    """Replace ``holo_search_sdk.connect`` with ``mock_connect`` when ``MOCK`` is enabled.

    When ``MOCK`` is false the fixture does nothing and ``holo.connect`` is left
    untouched. Otherwise, on teardown the shared mock tables are cleared and the
    monkeypatch is undone.
    """
    if not MOCK:
        yield
        return
    monkeypatch.setattr(holo, "connect", mock_connect)
    yield
    _mock_tables.clear()
    monkeypatch.undo()