fix: max_marginal_relevance_search and docs in Dingo (#9244)

This commit is contained in:
Hech 2023-08-15 16:06:06 +08:00 committed by GitHub
parent 664ff28cba
commit 4b505060bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 13 deletions

View File

@ -23,7 +23,9 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install dingodb" "!pip install dingodb\n",
"or install latest:\n",
"!pip install git+https://git@github.com/dingodb/pydingo.git"
] ]
}, },
{ {
@ -107,7 +109,7 @@
"dingo_client = DingoDB(user=\"\", password=\"\", host=[\"127.0.0.1:13000\"])\n", "dingo_client = DingoDB(user=\"\", password=\"\", host=[\"127.0.0.1:13000\"])\n",
"# First, check if our index already exists. If it doesn't, we create it\n", "# First, check if our index already exists. If it doesn't, we create it\n",
"if index_name not in dingo_client.get_index():\n", "if index_name not in dingo_client.get_index():\n",
" # we create a new index\n", " # we create a new index, modify to your own\n",
" dingo_client.create_index(\n", " dingo_client.create_index(\n",
" index_name=index_name,\n", " index_name=index_name,\n",
" dimension=1536,\n", " dimension=1536,\n",
@ -150,7 +152,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(docs[0][1])" "print(docs[0].page_content)"
] ]
}, },
{ {
@ -170,9 +172,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"vectorstore = Dingo(client, embeddings.embed_query, \"text\")\n", "vectorstore = Dingo(embeddings, \"text\", client=dingo_client, index_name=index_name)\n",
"\n", "\n",
"vectorstore.add_texts(\"More text!\")" "vectorstore.add_texts([\"More text!\"])"
] ]
}, },
{ {

View File

@ -112,9 +112,11 @@ class Dingo(VectorStore):
# upsert to Dingo # upsert to Dingo
for i in range(0, len(list(texts)), batch_size): for i in range(0, len(list(texts)), batch_size):
j = i + batch_size j = i + batch_size
self._client.vector_add( add_res = self._client.vector_add(
self._index_name, metadatas_list[i:j], embeds[i:j], ids[i:j] self._index_name, metadatas_list[i:j], embeds[i:j], ids[i:j]
) )
if not add_res:
raise Exception("vector add fail")
return ids return ids
@ -205,20 +207,26 @@ class Dingo(VectorStore):
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
results = self._client.vector_search( results = self._client.vector_search(
self._index_name, [embedding], search_params, k self._index_name, [embedding], search_params=search_params, top_k=k
) )
mmr_selected = maximal_marginal_relevance( mmr_selected = maximal_marginal_relevance(
np.array([embedding], dtype=np.float32), np.array([embedding], dtype=np.float32),
[item["floatValues"] for item in results[0]["vectorWithDistances"]], [
item["vector"]["floatValues"]
for item in results[0]["vectorWithDistances"]
],
k=k, k=k,
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
) )
selected = [ selected = []
results[0]["vectorWithDistances"][i]["metaData"] for i in mmr_selected for i in mmr_selected:
] meta_data = {}
for k, v in results[0]["vectorWithDistances"][i]["scalarData"].items():
meta_data.update({str(k): v["fields"][0]["data"]})
selected.append(meta_data)
return [ return [
Document(page_content=metadata.pop((self._text_key)), metadata=metadata) Document(page_content=metadata.pop(self._text_key), metadata=metadata)
for metadata in selected for metadata in selected
] ]
@ -328,9 +336,11 @@ class Dingo(VectorStore):
# upsert to Dingo # upsert to Dingo
for i in range(0, len(list(texts)), batch_size): for i in range(0, len(list(texts)), batch_size):
j = i + batch_size j = i + batch_size
dingo_client.vector_add( add_res = dingo_client.vector_add(
index_name, metadatas_list[i:j], embeds[i:j], ids[i:j] index_name, metadatas_list[i:j], embeds[i:j], ids[i:j]
) )
if not add_res:
raise Exception("vector add fail")
return cls(embedding, text_key, client=dingo_client, index_name=index_name) return cls(embedding, text_key, client=dingo_client, index_name=index_name)
def delete( def delete(