mirror of https://github.com/langgenius/dify.git
fix(api): return inserted ids from Chroma and Clickzetta add_texts (#33065)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
49dcf5e0d9
commit
f751864ab3
|
|
@ -65,7 +65,7 @@ class ChromaVector(BaseVector):
|
|||
self._client.get_or_create_collection(collection_name)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
|
@ -73,6 +73,7 @@ class ChromaVector(BaseVector):
|
|||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
# FIXME: chromadb using numpy array, fix the type error later
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
|
||||
return uuids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
|
|
|
|||
|
|
@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector):
|
|||
logger.warning("Failed to create inverted index: %s", e)
|
||||
# Continue without inverted index - full-text search will fall back to LIKE
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
"""Add documents with embeddings to the collection."""
|
||||
if not documents:
|
||||
return
|
||||
return []
|
||||
|
||||
batch_size = self._config.batch_size
|
||||
total_batches = (len(documents) + batch_size - 1) // batch_size
|
||||
added_ids = []
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
batch_doc_ids = []
|
||||
for doc in batch_docs:
|
||||
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
||||
batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))))
|
||||
added_ids.extend(batch_doc_ids)
|
||||
|
||||
# Execute batch insert through write queue
|
||||
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
|
||||
self._execute_write(
|
||||
self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches
|
||||
)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _insert_batch(
|
||||
self,
|
||||
batch_docs: list[Document],
|
||||
batch_embeddings: list[list[float]],
|
||||
batch_doc_ids: list[str],
|
||||
batch_index: int,
|
||||
batch_size: int,
|
||||
total_batches: int,
|
||||
|
|
@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector):
|
|||
data_rows = []
|
||||
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
|
||||
|
||||
for doc, embedding in zip(batch_docs, batch_embeddings):
|
||||
for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids):
|
||||
# Optimized: minimal checks for common case, fallback for edge cases
|
||||
metadata = doc.metadata or {}
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
|
||||
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
|
||||
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
||||
|
||||
# Fast path for JSON serialization
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue