This commit is contained in:
Ethan T. 2026-03-24 13:04:48 +08:00 committed by GitHub
commit a611729766
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 35 additions and 0 deletions

View File

@ -109,6 +109,7 @@ class WeaviateVector(BaseVector):
attributes: List of metadata attributes to store
"""
super().__init__(collection_name)
self._config = config
self._client = self._init_client(config)
self._attributes = attributes
@ -165,6 +166,32 @@ class WeaviateVector(BaseVector):
_weaviate_client = client
return client
def _ensure_connected(self) -> None:
"""
Ensures the Weaviate client connection is active.
The Weaviate client connection can time out during idle periods (e.g. hours
between knowledge base creation and recall testing). This method checks the
connection state and reconnects if necessary, preventing
"WeaviateClient is closed" errors.
"""
if self._client.is_connected():
return
try:
self._client.connect()
logger.info("Reconnected to Weaviate (collection=%s)", self._collection_name)
except Exception:
logger.warning(
"Failed to reconnect existing Weaviate client, creating a new one (collection=%s)",
self._collection_name,
)
try:
self._client.close()
except Exception:
pass
self._client = self._init_client(self._config)
def get_type(self) -> str:
"""Returns the vector database type identifier."""
return VectorType.WEAVIATE
@ -201,6 +228,7 @@ class WeaviateVector(BaseVector):
Uses Redis locking to prevent concurrent creation attempts.
"""
self._ensure_connected()
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
cache_key = f"vector_indexing_{self._collection_name}"
@ -286,6 +314,7 @@ class WeaviateVector(BaseVector):
Batches insertions for efficiency and returns the list of inserted object IDs.
"""
self._ensure_connected()
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
@ -332,6 +361,7 @@ class WeaviateVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Deletes all objects matching a specific metadata field value."""
self._ensure_connected()
if not self._client.collections.exists(self._collection_name):
return
@ -340,11 +370,13 @@ class WeaviateVector(BaseVector):
def delete(self):
"""Deletes the entire collection from Weaviate."""
self._ensure_connected()
if self._client.collections.exists(self._collection_name):
self._client.collections.delete(self._collection_name)
def text_exists(self, id: str) -> bool:
"""Checks if a document with the given doc_id exists in the collection."""
self._ensure_connected()
if not self._client.collections.exists(self._collection_name):
return False
@ -363,6 +395,7 @@ class WeaviateVector(BaseVector):
Silently ignores 404 errors for non-existent IDs.
"""
self._ensure_connected()
if not self._client.collections.exists(self._collection_name):
return
@ -382,6 +415,7 @@ class WeaviateVector(BaseVector):
Filters by document IDs if provided and applies score threshold.
Returns documents sorted by relevance score.
"""
self._ensure_connected()
if not self._client.collections.exists(self._collection_name):
return []
@ -429,6 +463,7 @@ class WeaviateVector(BaseVector):
Filters by document IDs if provided and returns matching documents with vectors.
"""
self._ensure_connected()
if not self._client.collections.exists(self._collection_name):
return []