mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-21 18:37:04 +00:00
Weaviate attributes and error handling (#2800)
This commit is contained in:
parent
0e763677e4
commit
1c7fb31bba
@ -1,7 +1,7 @@
|
|||||||
"""Wrapper around weaviate vector database."""
|
"""Wrapper around weaviate vector database."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import Extra
|
from pydantic import Extra
|
||||||
@ -18,6 +18,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
text_key: str,
|
text_key: str,
|
||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
attributes: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
@ -36,6 +37,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
self._index_name = index_name
|
self._index_name = index_name
|
||||||
self._text_key = text_key
|
self._text_key = text_key
|
||||||
self._query_attrs = [self._text_key]
|
self._query_attrs = [self._text_key]
|
||||||
|
if attributes is not None:
|
||||||
|
self._query_attrs.extend(attributes)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -67,6 +70,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
result = (
|
result = (
|
||||||
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
|
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
|
||||||
)
|
)
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
|
@ -83,6 +83,8 @@ class Weaviate(VectorStore):
|
|||||||
content["certainty"] = kwargs.get("search_distance")
|
content["certainty"] = kwargs.get("search_distance")
|
||||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
result = query_obj.with_near_text(content).with_limit(k).do()
|
result = query_obj.with_near_text(content).with_limit(k).do()
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
docs = []
|
docs = []
|
||||||
for res in result["data"]["Get"][self._index_name]:
|
for res in result["data"]["Get"][self._index_name]:
|
||||||
text = res.pop(self._text_key)
|
text = res.pop(self._text_key)
|
||||||
@ -96,6 +98,8 @@ class Weaviate(VectorStore):
|
|||||||
vector = {"vector": embedding}
|
vector = {"vector": embedding}
|
||||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
result = query_obj.with_near_vector(vector).with_limit(k).do()
|
result = query_obj.with_near_vector(vector).with_limit(k).do()
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
docs = []
|
docs = []
|
||||||
for res in result["data"]["Get"][self._index_name]:
|
for res in result["data"]["Get"][self._index_name]:
|
||||||
text = res.pop(self._text_key)
|
text = res.pop(self._text_key)
|
||||||
|
Loading…
Reference in New Issue
Block a user