mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-19 09:30:15 +00:00
Weaviate attributes and error handling (#2800)
This commit is contained in:
parent
0e763677e4
commit
1c7fb31bba
@ -1,7 +1,7 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Extra
|
||||
@ -18,6 +18,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
text_key: str,
|
||||
alpha: float = 0.5,
|
||||
k: int = 4,
|
||||
attributes: Optional[List[str]] = None,
|
||||
):
|
||||
try:
|
||||
import weaviate
|
||||
@ -36,6 +37,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
self._index_name = index_name
|
||||
self._text_key = text_key
|
||||
self._query_attrs = [self._text_key]
|
||||
if attributes is not None:
|
||||
self._query_attrs.extend(attributes)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -67,6 +70,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
result = (
|
||||
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
|
||||
)
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
docs = []
|
||||
|
||||
|
@ -83,6 +83,8 @@ class Weaviate(VectorStore):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
result = query_obj.with_near_text(content).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
@ -96,6 +98,8 @@ class Weaviate(VectorStore):
|
||||
vector = {"vector": embedding}
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
result = query_obj.with_near_vector(vector).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
|
Loading…
Reference in New Issue
Block a user