diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index de1572410c..cbc846f716 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -65,7 +65,7 @@ class ChromaVector(BaseVector): self._client.get_or_create_collection(collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] @@ -73,6 +73,7 @@ class ChromaVector(BaseVector): collection = self._client.get_or_create_collection(self._collection_name) # FIXME: chromadb using numpy array, fix the type error later collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore + return uuids def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 91bb71bfa6..8e8120fc10 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector): logger.warning("Failed to create inverted index: %s", e) # Continue without inverted index - full-text search will fall back to LIKE - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: """Add documents with embeddings to the collection.""" if not documents: - return + return [] batch_size = self._config.batch_size total_batches = (len(documents) + batch_size - 1) // batch_size + added_ids = [] for i in range(0, len(documents), batch_size): batch_docs = documents[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] + batch_doc_ids = [] + for doc in batch_docs: + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} + batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))) + added_ids.extend(batch_doc_ids) # Execute batch insert through write queue - self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + self._execute_write( + self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches + ) + + return added_ids def _insert_batch( self, batch_docs: list[Document], batch_embeddings: list[list[float]], + batch_doc_ids: list[str], batch_index: int, batch_size: int, total_batches: int, @@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector): data_rows = [] vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 - for doc, embedding in zip(batch_docs, batch_embeddings): + for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids): # Optimized: minimal checks for common case, fallback for edge cases - metadata = doc.metadata or {} - - if not isinstance(metadata, dict): - metadata = {} - - doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} # Fast path for JSON serialization try: