mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Add batch_size param to Weaviate vector store (#9890)
cc @mcantillon21 @hsm207 @cs0lar
This commit is contained in:
parent
720f6dbaac
commit
76dd7480e6
@ -1,17 +1,29 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
|
import os
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.schema.embeddings import Embeddings
|
from langchain.schema.embeddings import Embeddings
|
||||||
from langchain.utils import get_from_dict_or_env
|
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import weaviate
|
||||||
|
|
||||||
|
|
||||||
def _default_schema(index_name: str) -> Dict:
|
def _default_schema(index_name: str) -> Dict:
|
||||||
return {
|
return {
|
||||||
@ -25,21 +37,11 @@ def _default_schema(index_name: str) -> Dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _create_weaviate_client(**kwargs: Any) -> Any:
|
def _create_weaviate_client(
|
||||||
client = kwargs.get("client")
|
url: Optional[str] = None,
|
||||||
if client is not None:
|
api_key: Optional[str] = None,
|
||||||
return client
|
**kwargs: Any,
|
||||||
|
) -> weaviate.Client:
|
||||||
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# the weaviate api key param should not be mandatory
|
|
||||||
weaviate_api_key = get_from_dict_or_env(
|
|
||||||
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
weaviate_api_key = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -47,15 +49,10 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
|
|||||||
"Could not import weaviate python package. "
|
"Could not import weaviate python package. "
|
||||||
"Please install it with `pip install weaviate-client`"
|
"Please install it with `pip install weaviate-client`"
|
||||||
)
|
)
|
||||||
|
url = url or os.environ.get("WEAVIATE_URL")
|
||||||
auth = (
|
api_key = api_key or os.environ.get("WEAVIATE_API_KEY")
|
||||||
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
|
auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None
|
||||||
if weaviate_api_key is not None
|
return weaviate.Client(url=url, auth_client_secret=auth, **kwargs)
|
||||||
else None
|
|
||||||
)
|
|
||||||
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def _default_score_normalizer(val: float) -> float:
|
def _default_score_normalizer(val: float) -> float:
|
||||||
@ -78,6 +75,7 @@ class Weaviate(VectorStore):
|
|||||||
|
|
||||||
import weaviate
|
import weaviate
|
||||||
from langchain.vectorstores import Weaviate
|
from langchain.vectorstores import Weaviate
|
||||||
|
|
||||||
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
|
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
|
||||||
weaviate = Weaviate(client, index_name, text_key)
|
weaviate = Weaviate(client, index_name, text_key)
|
||||||
|
|
||||||
@ -375,10 +373,21 @@ class Weaviate(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls: Type[Weaviate],
|
cls,
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
*,
|
||||||
|
client: Optional[weaviate.Client] = None,
|
||||||
|
weaviate_url: Optional[str] = None,
|
||||||
|
weaviate_api_key: Optional[str] = None,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
index_name: Optional[str] = None,
|
||||||
|
text_key: str = "text",
|
||||||
|
by_text: bool = False,
|
||||||
|
relevance_score_fn: Optional[
|
||||||
|
Callable[[float], float]
|
||||||
|
] = _default_score_normalizer,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Weaviate:
|
) -> Weaviate:
|
||||||
"""Construct Weaviate wrapper from raw documents.
|
"""Construct Weaviate wrapper from raw documents.
|
||||||
@ -390,11 +399,34 @@ class Weaviate(VectorStore):
|
|||||||
|
|
||||||
This is intended to be a quick way to get started.
|
This is intended to be a quick way to get started.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Texts to add to vector store.
|
||||||
|
embedding: Text embedding model to use.
|
||||||
|
metadatas: Metadata associated with each text.
|
||||||
|
client: weaviate.Client to use.
|
||||||
|
weaviate_url: The Weaviate URL. If using Weaviate Cloud Services get it
|
||||||
|
from the ``Details`` tab. Can be passed in as a named param or by
|
||||||
|
setting the environment variable ``WEAVIATE_URL``. Should not be
|
||||||
|
specified if client is provided.
|
||||||
|
weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud
|
||||||
|
Services, get it from ``Details`` tab. Can be passed in as a named param
|
||||||
|
or by setting the environment variable ``WEAVIATE_API_KEY``. Should
|
||||||
|
not be specified if client is provided.
|
||||||
|
batch_size: Size of batch operations.
|
||||||
|
index_name: Index name.
|
||||||
|
text_key: Key to use for uploading/retrieving text to/from vectorstore.
|
||||||
|
by_text: Whether to search by text or by embedding.
|
||||||
|
relevance_score_fn: Function for converting whatever distance function the
|
||||||
|
vector store uses to a relevance score, which is a normalized similarity
|
||||||
|
score (0 means dissimilar, 1 means similar).
|
||||||
|
**kwargs: Additional named parameters to pass to ``Weaviate.__init__()``.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.vectorstores.weaviate import Weaviate
|
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
from langchain.vectorstores import Weaviate
|
||||||
|
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
weaviate = Weaviate.from_texts(
|
weaviate = Weaviate.from_texts(
|
||||||
texts,
|
texts,
|
||||||
@ -403,20 +435,30 @@ class Weaviate(VectorStore):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client = _create_weaviate_client(**kwargs)
|
try:
|
||||||
|
from weaviate.util import get_valid_uuid
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import weaviate python package. "
|
||||||
|
"Please install it with `pip install weaviate-client`"
|
||||||
|
) from e
|
||||||
|
|
||||||
from weaviate.util import get_valid_uuid
|
client = client or _create_weaviate_client(
|
||||||
|
url=weaviate_url,
|
||||||
|
api_key=weaviate_api_key,
|
||||||
|
)
|
||||||
|
if batch_size:
|
||||||
|
client.batch.configure(batch_size=batch_size)
|
||||||
|
|
||||||
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
|
index_name = index_name or f"LangChain_{uuid4().hex}"
|
||||||
embeddings = embedding.embed_documents(texts) if embedding else None
|
|
||||||
text_key = "text"
|
|
||||||
schema = _default_schema(index_name)
|
schema = _default_schema(index_name)
|
||||||
attributes = list(metadatas[0].keys()) if metadatas else None
|
|
||||||
|
|
||||||
# check whether the index already exists
|
# check whether the index already exists
|
||||||
if not client.schema.contains(schema):
|
if not client.schema.contains(schema):
|
||||||
client.schema.create_class(schema)
|
client.schema.create_class(schema)
|
||||||
|
|
||||||
|
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||||
|
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||||
|
|
||||||
with client.batch as batch:
|
with client.batch as batch:
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
data_properties = {
|
data_properties = {
|
||||||
@ -449,9 +491,6 @@ class Weaviate(VectorStore):
|
|||||||
|
|
||||||
batch.flush()
|
batch.flush()
|
||||||
|
|
||||||
relevance_score_fn = kwargs.get("relevance_score_fn")
|
|
||||||
by_text: bool = kwargs.get("by_text", False)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
client,
|
client,
|
||||||
index_name,
|
index_name,
|
||||||
@ -460,6 +499,7 @@ class Weaviate(VectorStore):
|
|||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
relevance_score_fn=relevance_score_fn,
|
relevance_score_fn=relevance_score_fn,
|
||||||
by_text=by_text,
|
by_text=by_text,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user