Harrison/optional ids opensearch (#6684)

Co-authored-by: taekimsmar <66041442+taekimsmar@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-06-24 09:19:57 -07:00 committed by GitHub
parent 2518e6c95b
commit c289cc891a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -77,6 +77,7 @@ def _bulk_ingest_embeddings(
embeddings: List[List[float]], embeddings: List[List[float]],
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
text_field: str = "text", text_field: str = "text",
mapping: Optional[Dict] = None, mapping: Optional[Dict] = None,
@ -88,7 +89,7 @@ def _bulk_ingest_embeddings(
bulk = _import_bulk() bulk = _import_bulk()
not_found_error = _import_not_found_error() not_found_error = _import_not_found_error()
requests = [] requests = []
ids = [] return_ids = []
mapping = mapping mapping = mapping
try: try:
@ -98,7 +99,7 @@ def _bulk_ingest_embeddings(
for i, text in enumerate(texts): for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {} metadata = metadatas[i] if metadatas else {}
_id = str(uuid.uuid4()) _id = ids[i] if ids else str(uuid.uuid4())
request = { request = {
"_op_type": "index", "_op_type": "index",
"_index": index_name, "_index": index_name,
@ -108,10 +109,10 @@ def _bulk_ingest_embeddings(
"_id": _id, "_id": _id,
} }
requests.append(request) requests.append(request)
ids.append(_id) return_ids.append(_id)
bulk(client, requests) bulk(client, requests)
client.indices.refresh(index=index_name) client.indices.refresh(index=index_name)
return ids return return_ids
def _default_scripting_text_mapping( def _default_scripting_text_mapping(
@ -318,6 +319,7 @@ class OpenSearchVectorSearch(VectorStore):
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
bulk_size: int = 500, bulk_size: int = 500,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
@ -326,6 +328,7 @@ class OpenSearchVectorSearch(VectorStore):
Args: Args:
texts: Iterable of strings to add to the vectorstore. texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts. metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of ids to associate with the texts.
bulk_size: Bulk API request count; Default: 500 bulk_size: Bulk API request count; Default: 500
Returns: Returns:
@ -358,10 +361,11 @@ class OpenSearchVectorSearch(VectorStore):
self.index_name, self.index_name,
embeddings, embeddings,
texts, texts,
metadatas, metadatas=metadatas,
vector_field, ids=ids,
text_field, vector_field=vector_field,
mapping, text_field=text_field,
mapping=mapping,
) )
def similarity_search( def similarity_search(
@ -679,9 +683,9 @@ class OpenSearchVectorSearch(VectorStore):
index_name, index_name,
embeddings, embeddings,
texts, texts,
metadatas, metadatas=metadatas,
vector_field, vector_field=vector_field,
text_field, text_field=text_field,
mapping, mapping=mapping,
) )
return cls(opensearch_url, index_name, embedding, **kwargs) return cls(opensearch_url, index_name, embedding, **kwargs)