mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Add batch_size param to Weaviate vector store (#9890)
cc @mcantillon21 @hsm207 @cs0lar
This commit is contained in:
parent
720f6dbaac
commit
76dd7480e6
@ -1,17 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
|
||||
import os
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import weaviate
|
||||
|
||||
|
||||
def _default_schema(index_name: str) -> Dict:
|
||||
return {
|
||||
@ -25,21 +37,11 @@ def _default_schema(index_name: str) -> Dict:
|
||||
}
|
||||
|
||||
|
||||
def _create_weaviate_client(**kwargs: Any) -> Any:
|
||||
client = kwargs.get("client")
|
||||
if client is not None:
|
||||
return client
|
||||
|
||||
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
|
||||
|
||||
try:
|
||||
# the weaviate api key param should not be mandatory
|
||||
weaviate_api_key = get_from_dict_or_env(
|
||||
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
|
||||
)
|
||||
except ValueError:
|
||||
weaviate_api_key = None
|
||||
|
||||
def _create_weaviate_client(
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> weaviate.Client:
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
@ -47,15 +49,10 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`"
|
||||
)
|
||||
|
||||
auth = (
|
||||
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
|
||||
if weaviate_api_key is not None
|
||||
else None
|
||||
)
|
||||
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
|
||||
|
||||
return client
|
||||
url = url or os.environ.get("WEAVIATE_URL")
|
||||
api_key = api_key or os.environ.get("WEAVIATE_API_KEY")
|
||||
auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None
|
||||
return weaviate.Client(url=url, auth_client_secret=auth, **kwargs)
|
||||
|
||||
|
||||
def _default_score_normalizer(val: float) -> float:
|
||||
@ -78,6 +75,7 @@ class Weaviate(VectorStore):
|
||||
|
||||
import weaviate
|
||||
from langchain.vectorstores import Weaviate
|
||||
|
||||
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
|
||||
weaviate = Weaviate(client, index_name, text_key)
|
||||
|
||||
@ -375,10 +373,21 @@ class Weaviate(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[Weaviate],
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
*,
|
||||
client: Optional[weaviate.Client] = None,
|
||||
weaviate_url: Optional[str] = None,
|
||||
weaviate_api_key: Optional[str] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
index_name: Optional[str] = None,
|
||||
text_key: str = "text",
|
||||
by_text: bool = False,
|
||||
relevance_score_fn: Optional[
|
||||
Callable[[float], float]
|
||||
] = _default_score_normalizer,
|
||||
**kwargs: Any,
|
||||
) -> Weaviate:
|
||||
"""Construct Weaviate wrapper from raw documents.
|
||||
@ -390,11 +399,34 @@ class Weaviate(VectorStore):
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Args:
|
||||
texts: Texts to add to vector store.
|
||||
embedding: Text embedding model to use.
|
||||
metadatas: Metadata associated with each text.
|
||||
client: weaviate.Client to use.
|
||||
weaviate_url: The Weaviate URL. If using Weaviate Cloud Services, get it
|
||||
from the ``Details`` tab. Can be passed in as a named param or by
|
||||
setting the environment variable ``WEAVIATE_URL``. Should not be
|
||||
specified if client is provided.
|
||||
weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud
|
||||
Services, get it from the ``Details`` tab. Can be passed in as a named
|
||||
param or by setting the environment variable ``WEAVIATE_API_KEY``.
|
||||
Should not be specified if client is provided.
|
||||
batch_size: Size of batch operations.
|
||||
index_name: Index name.
|
||||
text_key: Key to use for uploading/retrieving text to/from vectorstore.
|
||||
by_text: Whether to search by text or by embedding.
|
||||
relevance_score_fn: Function for converting whatever distance function the
|
||||
vector store uses to a relevance score, which is a normalized similarity
|
||||
score (0 means dissimilar, 1 means similar).
|
||||
**kwargs: Additional named parameters to pass to ``Weaviate.__init__()``.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores.weaviate import Weaviate
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import Weaviate
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
weaviate = Weaviate.from_texts(
|
||||
texts,
|
||||
@ -403,20 +435,30 @@ class Weaviate(VectorStore):
|
||||
)
|
||||
"""
|
||||
|
||||
client = _create_weaviate_client(**kwargs)
|
||||
|
||||
try:
|
||||
from weaviate.util import get_valid_uuid
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`"
|
||||
) from e
|
||||
|
||||
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
|
||||
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||
text_key = "text"
|
||||
client = client or _create_weaviate_client(
|
||||
url=weaviate_url,
|
||||
api_key=weaviate_api_key,
|
||||
)
|
||||
if batch_size:
|
||||
client.batch.configure(batch_size=batch_size)
|
||||
|
||||
index_name = index_name or f"LangChain_{uuid4().hex}"
|
||||
schema = _default_schema(index_name)
|
||||
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||
|
||||
# check whether the index already exists
|
||||
if not client.schema.contains(schema):
|
||||
client.schema.create_class(schema)
|
||||
|
||||
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||
|
||||
with client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {
|
||||
@ -449,9 +491,6 @@ class Weaviate(VectorStore):
|
||||
|
||||
batch.flush()
|
||||
|
||||
relevance_score_fn = kwargs.get("relevance_score_fn")
|
||||
by_text: bool = kwargs.get("by_text", False)
|
||||
|
||||
return cls(
|
||||
client,
|
||||
index_name,
|
||||
@ -460,6 +499,7 @@ class Weaviate(VectorStore):
|
||||
attributes=attributes,
|
||||
relevance_score_fn=relevance_score_fn,
|
||||
by_text=by_text,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user