mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 10:23:30 +00:00
#11763 --------- Co-authored-by: TranswarpHippo <hippo.0.assistant@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
342d6c7ab6
commit
92bf40a921
431
docs/docs/integrations/vectorstores/hippo.ipynb
Normal file
431
docs/docs/integrations/vectorstores/hippo.ipynb
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## Hippo\n",
|
||||||
|
"\n",
|
||||||
|
">[Hippo](https://www.transwarp.cn/starwarp) Please visit our official website for how to run a Hippo instance and\n",
|
||||||
|
"how to use functionality related to the Hippo vector database\n",
|
||||||
|
"\n",
|
||||||
|
"## Getting Started\n",
|
||||||
|
"\n",
|
||||||
|
"The only prerequisite here is an API key from the OpenAI website. Make sure you have already started a Hippo instance."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "357f24224a8e818f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## Installing Dependencies\n",
|
||||||
|
"\n",
|
||||||
|
"Initially, we require the installation of certain dependencies, such as OpenAI, Langchain, and Hippo-API. Please note, you should install the appropriate versions tailored to your environment."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "a92d2ce26df7ac4c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Requirement already satisfied: hippo-api==1.1.0.rc3 in /Users/daochengzhang/miniforge3/envs/py310/lib/python3.10/site-packages (1.1.0rc3)\r\n",
|
||||||
|
"Requirement already satisfied: pyyaml>=6.0 in /Users/daochengzhang/miniforge3/envs/py310/lib/python3.10/site-packages (from hippo-api==1.1.0.rc3) (6.0.1)\r\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"!pip install langchain tiktoken openai\n",
|
||||||
|
"!pip install hippo-api==1.1.0.rc3"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:47:54.718488Z",
|
||||||
|
"start_time": "2023-10-30T06:47:53.563129Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "13b1d1ae153ff434"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"Note: Python version needs to be >=3.8.\n",
|
||||||
|
"\n",
|
||||||
|
"## Best Practice\n",
|
||||||
|
"### Importing Dependency Packages"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "554081137df2c252"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import AzureChatOpenAI, ChatOpenAI\n",
|
||||||
|
"from langchain.document_loaders import TextLoader\n",
|
||||||
|
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||||
|
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||||
|
"from langchain.vectorstores.hippo import Hippo\n",
|
||||||
|
"import os"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:47:56.003409Z",
|
||||||
|
"start_time": "2023-10-30T06:47:55.998839Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "5ff3296ce812aeb8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Loading Knowledge Documents"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "dad255dae8aea755"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"os.environ[\"OPENAI_API_KEY\"] = \"YOUR OPENAI KEY\"\n",
|
||||||
|
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
|
||||||
|
"documents = loader.load()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:47:59.027869Z",
|
||||||
|
"start_time": "2023-10-30T06:47:59.023934Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "f02d66a7fd653dc1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Segmenting the Knowledge Document\n",
|
||||||
|
"\n",
|
||||||
|
"Here, we use Langchain's CharacterTextSplitter for segmentation. The delimiter is a period. After segmentation, the text segment does not exceed 1000 characters, and the number of repeated characters is 0."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "e9b93c330f1c6160"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
|
||||||
|
"docs = text_splitter.split_documents(documents)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:48:00.279351Z",
|
||||||
|
"start_time": "2023-10-30T06:48:00.275763Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "fe6b43175318331f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Declaring the Embedding Model\n",
|
||||||
|
"Below, we create the OpenAI or Azure embedding model using the OpenAIEmbeddings method from Langchain."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "eefe28c7c993ffdf"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# openai\n",
|
||||||
|
"embeddings = OpenAIEmbeddings()\n",
|
||||||
|
"# azure\n",
|
||||||
|
"# embeddings = OpenAIEmbeddings(\n",
|
||||||
|
"# openai_api_type=\"azure\",\n",
|
||||||
|
"# openai_api_base=\"x x x\",\n",
|
||||||
|
"# openai_api_version=\"x x x\",\n",
|
||||||
|
"# model=\"x x x\",\n",
|
||||||
|
"# deployment=\"x x x\",\n",
|
||||||
|
"# openai_api_key=\"x x x\"\n",
|
||||||
|
"# )"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:48:11.686166Z",
|
||||||
|
"start_time": "2023-10-30T06:48:11.664355Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "8619f16b9f7355ea"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Declaring Hippo Client"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "e60235602ed91d3c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"HIPPO_CONNECTION = {\"host\": \"IP\", \"port\": \"PORT\"}"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:48:48.594298Z",
|
||||||
|
"start_time": "2023-10-30T06:48:48.585267Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "c666b70dcab78129"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Storing the Document"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "43ee6dbd765c3172"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"input...\n",
|
||||||
|
"success\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"input...\")\n",
|
||||||
|
"# insert docs\n",
|
||||||
|
"vector_store = Hippo.from_documents(\n",
|
||||||
|
" docs,\n",
|
||||||
|
" embedding=embeddings,\n",
|
||||||
|
" table_name=\"langchain_test\",\n",
|
||||||
|
" connection_args=HIPPO_CONNECTION,\n",
|
||||||
|
")\n",
|
||||||
|
"print(\"success\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:51:12.661741Z",
|
||||||
|
"start_time": "2023-10-30T06:51:06.257156Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "79372c869844bdc9"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Conducting Knowledge-based Question and Answer\n",
|
||||||
|
"#### Creating a Large Language Question-Answering Model\n",
|
||||||
|
"Below, we create the OpenAI or Azure large language question-answering model respectively using the AzureChatOpenAI and ChatOpenAI methods from Langchain."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "89077cc9763d5dd0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# llm = AzureChatOpenAI(\n",
|
||||||
|
"# openai_api_base=\"x x x\",\n",
|
||||||
|
"# openai_api_version=\"xxx\",\n",
|
||||||
|
"# deployment_name=\"xxx\",\n",
|
||||||
|
"# openai_api_key=\"xxx\",\n",
|
||||||
|
"# openai_api_type=\"azure\"\n",
|
||||||
|
"# )\n",
|
||||||
|
"\n",
|
||||||
|
"llm = ChatOpenAI(openai_api_key=\"YOUR OPENAI KEY\", model_name=\"gpt-3.5-turbo-16k\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:51:28.329351Z",
|
||||||
|
"start_time": "2023-10-30T06:51:28.318713Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "c9f2c42e9884f628"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Acquiring Related Knowledge Based on the Question:"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "a4c5d73016a9db0c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query = \"Please introduce COVID-19\"\n",
|
||||||
|
"# query = \"Please introduce Hippo Core Architecture\"\n",
|
||||||
|
"# query = \"What operations does the Hippo Vector Database support for vector data?\"\n",
|
||||||
|
"# query = \"Does Hippo use hardware acceleration technology? Briefly introduce hardware acceleration technology.\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Retrieve similar content from the knowledge base,fetch the top two most similar texts.\n",
|
||||||
|
"res = vector_store.similarity_search(query, 2)\n",
|
||||||
|
"content_list = [item.page_content for item in res]\n",
|
||||||
|
"text = \"\".join(content_list)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:51:33.195634Z",
|
||||||
|
"start_time": "2023-10-30T06:51:32.196493Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "8656e80519da1f97"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Constructing a Prompt Template"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "e5adbaaa7086d1ae"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"prompt = f\"\"\"\n",
|
||||||
|
"Please use the content of the following [Article] to answer my question. If you don't know, please say you don't know, and the answer should be concise.\"\n",
|
||||||
|
"[Article]:{text}\n",
|
||||||
|
"Please answer this question in conjunction with the above article:{query}\n",
|
||||||
|
"\"\"\""
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:51:35.649376Z",
|
||||||
|
"start_time": "2023-10-30T06:51:35.645763Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "b915d3001a2741c1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Waiting for the Large Language Model to Generate an Answer"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "b36b6a9adbec8a82"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"response_with_hippo:COVID-19 is a virus that has impacted every aspect of our lives for over two years. It is a highly contagious and mutates easily, requiring us to remain vigilant in combating its spread. However, due to progress made and the resilience of individuals, we are now able to move forward safely and return to more normal routines.\n",
|
||||||
|
"==========================================\n",
|
||||||
|
"response_without_hippo:COVID-19 is a contagious respiratory illness caused by the novel coronavirus SARS-CoV-2. It was first identified in December 2019 in Wuhan, China and has since spread globally, leading to a pandemic. The virus primarily spreads through respiratory droplets when an infected person coughs, sneezes, talks, or breathes, and can also spread by touching contaminated surfaces and then touching the face. COVID-19 symptoms include fever, cough, shortness of breath, fatigue, muscle or body aches, sore throat, loss of taste or smell, headache, and in severe cases, pneumonia and organ failure. While most people experience mild to moderate symptoms, it can lead to severe illness and even death, particularly among older adults and those with underlying health conditions. To combat the spread of the virus, various preventive measures have been implemented globally, including social distancing, wearing face masks, practicing good hand hygiene, and vaccination efforts.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"response_with_hippo = llm.predict(prompt)\n",
|
||||||
|
"print(f\"response_with_hippo:{response_with_hippo}\")\n",
|
||||||
|
"response = llm.predict(query)\n",
|
||||||
|
"print(\"==========================================\")\n",
|
||||||
|
"print(f\"response_without_hippo:{response}\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-10-30T06:52:17.967885Z",
|
||||||
|
"start_time": "2023-10-30T06:51:37.692819Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "58eb5d2396321001"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"start_time": "2023-10-30T06:42:42.172639Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "b2b7ce4e1850ecf1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
677
libs/langchain/langchain/vectorstores/hippo.py
Normal file
677
libs/langchain/langchain/vectorstores/hippo.py
Normal file
@ -0,0 +1,677 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.schema.embeddings import Embeddings
|
||||||
|
from langchain.schema.vectorstore import VectorStore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transwarp_hippo_api.hippo_client import HippoClient
|
||||||
|
|
||||||
|
# Default connection
|
||||||
|
DEFAULT_HIPPO_CONNECTION = {
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "7788",
|
||||||
|
"username": "admin",
|
||||||
|
"password": "admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Hippo(VectorStore):
|
||||||
|
"""`Hippo` vector store.
|
||||||
|
|
||||||
|
You need to install `hippo-api` and run Hippo.
|
||||||
|
|
||||||
|
Please visit our official website for how to run a Hippo instance:
|
||||||
|
https://www.transwarp.cn/starwarp
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_function (Embeddings): Function used to embed the text.
|
||||||
|
table_name (str): Which Hippo table to use. Defaults to
|
||||||
|
"test".
|
||||||
|
database_name (str): Which Hippo database to use. Defaults to
|
||||||
|
"default".
|
||||||
|
number_of_shards (int): The number of shards for the Hippo table.Defaults to
|
||||||
|
1.
|
||||||
|
number_of_replicas (int): The number of replicas for the Hippo table.Defaults to
|
||||||
|
1.
|
||||||
|
connection_args (Optional[dict[str, any]]): The connection args used for
|
||||||
|
this class comes in the form of a dict.
|
||||||
|
index_params (Optional[dict]): Which index params to use. Defaults to
|
||||||
|
IVF_FLAT.
|
||||||
|
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
|
||||||
|
to False.
|
||||||
|
primary_field (str): Name of the primary key field. Defaults to "pk".
|
||||||
|
text_field (str): Name of the text field. Defaults to "text".
|
||||||
|
vector_field (str): Name of the vector field. Defaults to "vector".
|
||||||
|
|
||||||
|
The connection args used for this class comes in the form of a dict,
|
||||||
|
here are a few of the options:
|
||||||
|
host (str): The host of Hippo instance. Default at "localhost".
|
||||||
|
port (str/int): The port of Hippo instance. Default at 7788.
|
||||||
|
user (str): Use which user to connect to Hippo instance. If user and
|
||||||
|
password are provided, we will add related header in every RPC call.
|
||||||
|
password (str): Required when user is provided. The password
|
||||||
|
corresponding to the user.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.vectorstores import Hippo
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
embedding = OpenAIEmbeddings()
|
||||||
|
# Connect to a hippo instance on localhost
|
||||||
|
vector_store = Hippo.from_documents(
|
||||||
|
docs,
|
||||||
|
embedding=embeddings,
|
||||||
|
table_name="langchain_test",
|
||||||
|
connection_args=HIPPO_CONNECTION
|
||||||
|
)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the hippo-api python package is not installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_function: Embeddings,
|
||||||
|
table_name: str = "test",
|
||||||
|
database_name: str = "default",
|
||||||
|
number_of_shards: int = 1,
|
||||||
|
number_of_replicas: int = 1,
|
||||||
|
connection_args: Optional[Dict[str, Any]] = None,
|
||||||
|
index_params: Optional[dict] = None,
|
||||||
|
drop_old: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
self.number_of_shards = number_of_shards
|
||||||
|
self.number_of_replicas = number_of_replicas
|
||||||
|
self.embedding_func = embedding_function
|
||||||
|
self.table_name = table_name
|
||||||
|
self.database_name = database_name
|
||||||
|
self.index_params = index_params
|
||||||
|
|
||||||
|
# In order for a collection to be compatible,
|
||||||
|
# 'pk' should be an auto-increment primary key and string
|
||||||
|
self._primary_field = "pk"
|
||||||
|
# In order for compatibility, the text field will need to be called "text"
|
||||||
|
self._text_field = "text"
|
||||||
|
# In order for compatibility, the vector field needs to be called "vector"
|
||||||
|
self._vector_field = "vector"
|
||||||
|
self.fields: List[str] = []
|
||||||
|
# Create the connection to the server
|
||||||
|
if connection_args is None:
|
||||||
|
connection_args = DEFAULT_HIPPO_CONNECTION
|
||||||
|
self.hc = self._create_connection_alias(connection_args)
|
||||||
|
self.col: Any = None
|
||||||
|
|
||||||
|
# If the collection exists, delete it
|
||||||
|
try:
|
||||||
|
if (
|
||||||
|
self.hc.check_table_exists(self.table_name, self.database_name)
|
||||||
|
and drop_old
|
||||||
|
):
|
||||||
|
self.hc.delete_table(self.table_name, self.database_name)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"An error occurred while deleting the table " f"{self.table_name}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.hc.check_table_exists(self.table_name, self.database_name):
|
||||||
|
self.col = self.hc.get_table(self.table_name, self.database_name)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"An error occurred while getting the table " f"{self.table_name}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Initialize the vector database
|
||||||
|
self._get_env()
|
||||||
|
|
||||||
|
def _create_connection_alias(self, connection_args: dict) -> HippoClient:
|
||||||
|
"""Create the connection to the Hippo server."""
|
||||||
|
# Grab the connection arguments that are used for checking existing connection
|
||||||
|
try:
|
||||||
|
from transwarp_hippo_api.hippo_client import HippoClient
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Unable to import transwarp_hipp_api, please install with "
|
||||||
|
"`pip install hippo-api`."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
host: str = connection_args.get("host", None)
|
||||||
|
port: int = connection_args.get("port", None)
|
||||||
|
username: str = connection_args.get("username", "shiva")
|
||||||
|
password: str = connection_args.get("password", "shiva")
|
||||||
|
|
||||||
|
# Order of use is host/port, uri, address
|
||||||
|
if host is not None and port is not None:
|
||||||
|
if "," in host:
|
||||||
|
hosts = host.split(",")
|
||||||
|
given_address = ",".join([f"{h}:{port}" for h in hosts])
|
||||||
|
else:
|
||||||
|
given_address = str(host) + ":" + str(port)
|
||||||
|
else:
|
||||||
|
raise ValueError("Missing standard address type for reuse attempt")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"create HippoClient[{given_address}]")
|
||||||
|
return HippoClient([given_address], username=username, pwd=password)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to create new connection")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _get_env(
|
||||||
|
self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
|
||||||
|
) -> None:
|
||||||
|
logger.info("init ...")
|
||||||
|
if embeddings is not None:
|
||||||
|
logger.info("create collection")
|
||||||
|
self._create_collection(embeddings, metadatas)
|
||||||
|
self._extract_fields()
|
||||||
|
self._create_index()
|
||||||
|
|
||||||
|
def _create_collection(
|
||||||
|
self, embeddings: list, metadatas: Optional[List[dict]] = None
|
||||||
|
) -> None:
|
||||||
|
from transwarp_hippo_api.hippo_client import HippoField
|
||||||
|
from transwarp_hippo_api.hippo_type import HippoType
|
||||||
|
|
||||||
|
# Determine embedding dim
|
||||||
|
dim = len(embeddings[0])
|
||||||
|
logger.debug(f"[_create_collection] dim: {dim}")
|
||||||
|
fields = []
|
||||||
|
|
||||||
|
# Create the primary key field
|
||||||
|
fields.append(HippoField(self._primary_field, True, HippoType.STRING))
|
||||||
|
|
||||||
|
# Create the text field
|
||||||
|
|
||||||
|
fields.append(HippoField(self._text_field, False, HippoType.STRING))
|
||||||
|
|
||||||
|
# Create the vector field, supports binary or float vectors
|
||||||
|
# to The binary vector type is to be developed.
|
||||||
|
fields.append(
|
||||||
|
HippoField(
|
||||||
|
self._vector_field,
|
||||||
|
False,
|
||||||
|
HippoType.FLOAT_VECTOR,
|
||||||
|
type_params={"dimension": dim},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# to In Hippo,there is no method similar to the infer_type_data
|
||||||
|
# types, so currently all non-vector data is converted to string type.
|
||||||
|
|
||||||
|
if metadatas:
|
||||||
|
# # Create FieldSchema for each entry in metadata.
|
||||||
|
for key, value in metadatas[0].items():
|
||||||
|
# # Infer the corresponding datatype of the metadata
|
||||||
|
if isinstance(value, list):
|
||||||
|
value_dim = len(value)
|
||||||
|
fields.append(
|
||||||
|
HippoField(
|
||||||
|
key,
|
||||||
|
False,
|
||||||
|
HippoType.FLOAT_VECTOR,
|
||||||
|
type_params={"dimension": value_dim},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fields.append(HippoField(key, False, HippoType.STRING))
|
||||||
|
|
||||||
|
logger.debug(f"[_create_collection] fields: {fields}")
|
||||||
|
|
||||||
|
# Create the collection
|
||||||
|
self.hc.create_table(
|
||||||
|
name=self.table_name,
|
||||||
|
auto_id=True,
|
||||||
|
fields=fields,
|
||||||
|
database_name=self.database_name,
|
||||||
|
number_of_shards=self.number_of_shards,
|
||||||
|
number_of_replicas=self.number_of_replicas,
|
||||||
|
)
|
||||||
|
self.col = self.hc.get_table(self.table_name, self.database_name)
|
||||||
|
logger.info(
|
||||||
|
f"[_create_collection] : "
|
||||||
|
f"create table {self.table_name} in {self.database_name} successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_fields(self) -> None:
|
||||||
|
"""Grab the existing fields from the Collection"""
|
||||||
|
from transwarp_hippo_api.hippo_client import HippoTable
|
||||||
|
|
||||||
|
if isinstance(self.col, HippoTable):
|
||||||
|
schema = self.col.schema
|
||||||
|
logger.debug(f"[_extract_fields] schema:{schema}")
|
||||||
|
for x in schema:
|
||||||
|
self.fields.append(x.name)
|
||||||
|
logger.debug(f"04 [_extract_fields] fields:{self.fields}")
|
||||||
|
|
||||||
|
# TO CAN: Translated into English, your statement would be: "Currently,
|
||||||
|
# only the field named 'vector' (the automatically created vector field)
|
||||||
|
# is checked for indexing. Indexes need to be created manually for other
|
||||||
|
# vector type columns.
|
||||||
|
def _get_index(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Return the vector index information if it exists"""
|
||||||
|
from transwarp_hippo_api.hippo_client import HippoTable
|
||||||
|
|
||||||
|
if isinstance(self.col, HippoTable):
|
||||||
|
table_info = self.hc.get_table_info(
|
||||||
|
self.table_name, self.database_name
|
||||||
|
).get(self.table_name, {})
|
||||||
|
embedding_indexes = table_info.get("embedding_indexes", None)
|
||||||
|
if embedding_indexes is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
for x in self.hc.get_table_info(self.table_name, self.database_name)[
|
||||||
|
self.table_name
|
||||||
|
]["embedding_indexes"]:
|
||||||
|
logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
|
||||||
|
if x["column"] == self._vector_field:
|
||||||
|
return x
|
||||||
|
return None
|
||||||
|
|
||||||
|
# TO Indexes can only be created for the self._vector_field field.
|
||||||
|
def _create_index(self) -> None:
|
||||||
|
"""Create a index on the collection"""
|
||||||
|
from transwarp_hippo_api.hippo_client import HippoTable
|
||||||
|
from transwarp_hippo_api.hippo_type import IndexType, MetricType
|
||||||
|
|
||||||
|
if isinstance(self.col, HippoTable) and self._get_index() is None:
|
||||||
|
if self._get_index() is None:
|
||||||
|
if self.index_params is None:
|
||||||
|
self.index_params = {
|
||||||
|
"index_name": "langchain_auto_create",
|
||||||
|
"metric_type": MetricType.L2,
|
||||||
|
"index_type": IndexType.IVF_FLAT,
|
||||||
|
"nlist": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
self.index_params["index_name"],
|
||||||
|
self.index_params["index_type"],
|
||||||
|
self.index_params["metric_type"],
|
||||||
|
nlist=self.index_params["nlist"],
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
self.col.activate_index(self.index_params["index_name"])
|
||||||
|
)
|
||||||
|
logger.info("create index successfully")
|
||||||
|
else:
|
||||||
|
index_dict = {
|
||||||
|
"IVF_FLAT": IndexType.IVF_FLAT,
|
||||||
|
"FLAT": IndexType.FLAT,
|
||||||
|
"IVF_SQ": IndexType.IVF_SQ,
|
||||||
|
"IVF_PQ": IndexType.IVF_PQ,
|
||||||
|
"HNSW": IndexType.HNSW,
|
||||||
|
}
|
||||||
|
|
||||||
|
metric_dict = {
|
||||||
|
"ip": MetricType.IP,
|
||||||
|
"IP": MetricType.IP,
|
||||||
|
"l2": MetricType.L2,
|
||||||
|
"L2": MetricType.L2,
|
||||||
|
}
|
||||||
|
self.index_params["metric_type"] = metric_dict[
|
||||||
|
self.index_params["metric_type"]
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.index_params["index_type"] == "FLAT":
|
||||||
|
self.index_params["index_type"] = index_dict[
|
||||||
|
self.index_params["index_type"]
|
||||||
|
]
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
self.index_params["index_name"],
|
||||||
|
self.index_params["index_type"],
|
||||||
|
self.index_params["metric_type"],
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
self.col.activate_index(self.index_params["index_name"])
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
self.index_params["index_type"] == "IVF_FLAT"
|
||||||
|
or self.index_params["index_type"] == "IVF_SQ"
|
||||||
|
):
|
||||||
|
self.index_params["index_type"] = index_dict[
|
||||||
|
self.index_params["index_type"]
|
||||||
|
]
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
self.index_params["index_name"],
|
||||||
|
self.index_params["index_type"],
|
||||||
|
self.index_params["metric_type"],
|
||||||
|
nlist=self.index_params.get("nlist", 10),
|
||||||
|
nprobe=self.index_params.get("nprobe", 10),
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
self.col.activate_index(self.index_params["index_name"])
|
||||||
|
)
|
||||||
|
elif self.index_params["index_type"] == "IVF_PQ":
|
||||||
|
self.index_params["index_type"] = index_dict[
|
||||||
|
self.index_params["index_type"]
|
||||||
|
]
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
self.index_params["index_name"],
|
||||||
|
self.index_params["index_type"],
|
||||||
|
self.index_params["metric_type"],
|
||||||
|
nlist=self.index_params.get("nlist", 10),
|
||||||
|
nprobe=self.index_params.get("nprobe", 10),
|
||||||
|
nbits=self.index_params.get("nbits", 8),
|
||||||
|
m=self.index_params.get("m"),
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
self.col.activate_index(self.index_params["index_name"])
|
||||||
|
)
|
||||||
|
elif self.index_params["index_type"] == "HNSW":
|
||||||
|
self.index_params["index_type"] = index_dict[
|
||||||
|
self.index_params["index_type"]
|
||||||
|
]
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
self.index_params["index_name"],
|
||||||
|
self.index_params["index_type"],
|
||||||
|
self.index_params["metric_type"],
|
||||||
|
M=self.index_params.get("M"),
|
||||||
|
ef_construction=self.index_params.get("ef_construction"),
|
||||||
|
ef_search=self.index_params.get("ef_search"),
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
self.col.activate_index(self.index_params["index_name"])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Index name does not match, "
|
||||||
|
"please enter the correct index name. "
|
||||||
|
"(FLAT, IVF_FLAT, IVF_PQ,IVF_SQ, HNSW)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    timeout: Optional[int] = None,
    batch_size: int = 1000,
    **kwargs: Any,
) -> List[str]:
    """
    Add text to the collection.

    Args:
        texts: An iterable that contains the text to be added.
        metadatas: An optional list of dictionaries,
            each dictionary contains the metadata associated with a text.
        timeout: Optional timeout, in seconds.
            NOTE(review): accepted but currently unused by this implementation.
        batch_size: The number of texts inserted in each batch, defaults to 1000.
        **kwargs: Other optional parameters.

    Returns:
        A list of strings, containing the unique identifiers of the inserted texts.
        NOTE(review): the underlying client call does not surface generated ids,
        so a single-element placeholder list is returned.

    Note:
        If the collection has not yet been created,
        this method will create a new collection.
    """
    from transwarp_hippo_api.hippo_client import HippoTable

    # Materialize first: ``texts`` may be a one-shot generator, and it is
    # consumed both by the emptiness check and by the embedding call below.
    texts = list(texts)

    if not texts or all(t == "" for t in texts):
        logger.debug("Nothing to insert, skipping.")
        return []

    logger.debug(f"[add_texts] texts: {texts}")

    # Prefer batch embedding; fall back to per-text queries when the
    # embedding implementation does not support ``embed_documents``.
    try:
        embeddings = self.embedding_func.embed_documents(texts)
    except NotImplementedError:
        embeddings = [self.embedding_func.embed_query(x) for x in texts]

    if len(embeddings) == 0:
        logger.debug("Nothing to insert, skipping.")
        return []

    logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")

    # Create the collection if it has not been created yet.
    if not isinstance(self.col, HippoTable):
        self._get_env(embeddings, metadatas)

    # Dict to hold all insert columns (column name -> list of values).
    insert_dict: Dict[str, list] = {
        self._text_field: texts,
        self._vector_field: embeddings,
    }
    logger.debug(f"[add_texts] metadatas:{metadatas}")
    logger.debug(f"[add_texts] fields:{self.fields}")
    if metadatas is not None:
        # Only metadata keys that are declared collection fields are kept.
        # NOTE(review): if a key is missing from some metadata dicts, the
        # resulting columns can be shorter than ``texts`` — confirm callers
        # always supply uniform metadata.
        for d in metadatas:
            for key, value in d.items():
                if key in self.fields:
                    insert_dict.setdefault(key, []).append(value)

    logger.debug(insert_dict[self._text_field])

    # Total insert count.
    vectors: list = insert_dict[self._vector_field]
    total_count = len(vectors)

    # "pk" is not inserted explicitly — presumably auto-generated by Hippo.
    if "pk" in self.fields:
        self.fields.remove("pk")

    logger.debug(f"[add_texts] total_count:{total_count}")
    for i in range(0, total_count, batch_size):
        # Grab end index.
        end = min(i + batch_size, total_count)
        # Convert dict to list of lists batch for insertion.
        insert_list = [insert_dict[x][i:end] for x in self.fields]
        try:
            res = self.col.insert_rows(insert_list)
            logger.info(f"05 [add_texts] insert {res}")
        except Exception as e:
            logger.error(
                "Failed to insert batch starting at entity: %s/%s", i, total_count
            )
            raise e
    return [""]
def similarity_search(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """
    Return the documents most similar to ``query``.

    Args:
        query (str): The text to search for.
        k (int, optional): The number of results to return. Default is 4.
        param (dict, optional): Specifies the search parameters for the index.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): Time to wait before a timeout error.
            Defaults to None.
        kwargs: Keyword arguments for Collection.search().

    Returns:
        List[Document]: The document results of the search (scores dropped).
    """
    # Without a collection there is nothing to search against.
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []
    scored = self.similarity_search_with_score(
        query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
    # Keep only the documents; discard the accompanying scores.
    return [document for document, _score in scored]
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """
    Search for documents similar to ``query`` and return them with scores.

    Args:
        query (str): The text being searched.
        k (int, optional): The number of results to return. Default is 4.
        param (dict): Specifies the search parameters for the index.
            Default is None.
        expr (str, optional): Filtering expression. Default is None.
        timeout (int, optional): The waiting time before a timeout error.
            Default is None.
        kwargs: Keyword arguments for Collection.search().

    Returns:
        List[Tuple[Document, float]]: Resulting documents and scores.
    """
    # Without a collection there is nothing to search against.
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    # Embed the query text, then delegate to the vector-based search.
    query_embedding = self.embedding_func.embed_query(query)
    return self.similarity_search_with_score_by_vector(
        embedding=query_embedding,
        k=k,
        param=param,
        expr=expr,
        timeout=timeout,
        **kwargs,
    )
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """
    Search by a precomputed embedding vector and return results with scores.

    Args:
        embedding (List[float]): The embedding vector being searched.
        k (int, optional): The number of results to return.
            Default is 4.
        param (dict): Specifies the search parameters for the index.
            Default is None. NOTE(review): currently unused in this body.
        expr (str, optional): Filtering expression. Default is None.
        timeout (int, optional): The waiting time before a timeout error.
            Default is None. NOTE(review): currently unused in this body.
        kwargs: Keyword arguments for Collection.search().

    Returns:
        List[Tuple[Document, float]]: Resulting documents and scores.
    """
    # Without a collection there is nothing to search against.
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    # if param is None:
    #     param = self.search_params

    # Determine result metadata fields: every declared field except the
    # vector column itself.
    output_fields = self.fields[:]
    output_fields.remove(self._vector_field)

    # Perform the search.
    logger.debug(f"search_field:{self._vector_field}")
    logger.debug(f"vectors:{[embedding]}")
    logger.debug(f"output_fields:{output_fields}")
    logger.debug(f"topk:{k}")
    logger.debug(f"dsl:{expr}")

    res = self.col.query(
        search_field=self._vector_field,
        vectors=[embedding],
        output_fields=output_fields,
        topk=k,
        dsl=expr,
    )
    # Organize results. ``res[0]`` holds the hits for the single query
    # vector sent above; each output field maps to a column of values.
    logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")
    # Scores live under "<text_field>%scores" — presumably a fixed naming
    # convention of the Hippo client; confirm against the client docs.
    score_col = self._text_field + "%scores"
    ret = []
    count = 0
    # Walk the per-field columns row-wise, rebuilding one hit per row.
    for items in zip(*[res[0][field] for field in output_fields]):
        meta = {field: value for field, value in zip(output_fields, items)}
        # The text column becomes the page content; the rest is metadata.
        doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
        logger.debug(
            f"[similarity_search_with_score_by_vector] "
            f"res[0][score_col]:{res[0][score_col]}"
        )
        score = res[0][score_col][count]
        count += 1
        ret.append((doc, score))

    return ret
@classmethod
def from_texts(
    cls,
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    table_name: str = "test",
    database_name: str = "default",
    connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
    index_params: Optional[Dict[Any, Any]] = None,
    search_params: Optional[Dict[str, Any]] = None,
    drop_old: bool = False,
    **kwargs: Any,
) -> "Hippo":
    """
    Build a Hippo vector store instance from the given texts.

    Args:
        texts (List[str]): List of texts to be added.
        embedding (Embeddings): Embedding model for the texts.
        metadatas (List[dict], optional):
            List of metadata dictionaries for each text. Defaults to None.
        table_name (str): Name of the table. Defaults to "test".
        database_name (str): Name of the database. Defaults to "default".
        connection_args (dict[str, Any]): Connection parameters.
            Defaults to DEFAULT_HIPPO_CONNECTION.
        index_params (dict): Indexing parameters. Defaults to None.
        search_params (dict): Search parameters. Defaults to an empty dictionary.
        drop_old (bool): Whether to drop the old collection. Defaults to False.
        kwargs: Other arguments.

    Returns:
        Hippo: An instance of the VST class.
    """
    if search_params is None:
        search_params = {}
    # NOTE(review): ``search_params`` is normalized above but never passed
    # to the constructor below — confirm whether it should be forwarded.
    logger.info("00 [from_texts] init the class of Hippo")
    store = cls(
        embedding_function=embedding,
        table_name=table_name,
        database_name=database_name,
        connection_args=connection_args,
        index_params=index_params,
        drop_old=drop_old,
        **kwargs,
    )
    logger.debug(f"[from_texts] texts:{texts}")
    logger.debug(f"[from_texts] metadatas:{metadatas}")
    store.add_texts(texts=texts, metadatas=metadatas)
    return store
@ -0,0 +1,63 @@
|
|||||||
|
"""Test Hippo functionality."""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.vectorstores.hippo import Hippo
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
|
FakeEmbeddings,
|
||||||
|
fake_texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _hippo_from_texts(
    metadatas: Optional[List[dict]] = None, drop: bool = True
) -> Hippo:
    """Build a Hippo store over the shared fake-text fixture (test helper)."""
    connection = {"host": "127.0.0.1", "port": 7788}
    return Hippo.from_texts(
        fake_texts,
        FakeEmbeddings(),
        metadatas=metadatas,
        table_name="langchain_test",
        connection_args=connection,
        drop_old=drop,
    )
def test_hippo_add_extra() -> None:
    """Test end-to-end construction, adding extra texts, and search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _hippo_from_texts(metadatas=metadatas)

    # Insert the same texts a second time on top of the initial load.
    docsearch.add_texts(texts, metadatas)

    output = docsearch.similarity_search("foo", k=1)
    assert len(output) == 1
def test_hippo() -> None:
    """Basic end-to-end check: the closest match to "foo" is "foo" itself."""
    store = _hippo_from_texts()
    results = store.similarity_search("foo", k=1)
    assert results == [Document(page_content="foo")]
def test_hippo_with_score() -> None:
    """Test end to end construction and search with scores and IDs."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    store = _hippo_from_texts(metadatas=metadatas)
    scored = store.similarity_search_with_score("foo", k=3)
    docs = [doc for doc, _ in scored]
    scores = [score for _, score in scored]
    # Metadata values come back as strings, hence "0"/"1"/"2".
    assert docs == [
        Document(page_content="foo", metadata={"page": "0"}),
        Document(page_content="bar", metadata={"page": "1"}),
        Document(page_content="baz", metadata={"page": "2"}),
    ]
    # Best match first: scores must be strictly increasing (distance-like).
    assert scores[0] < scores[1] < scores[2]
# if __name__ == "__main__":
|
||||||
|
# test_hippo()
|
||||||
|
# test_hippo_with_score()
|
||||||
|
# test_hippo_with_score()
|
Loading…
Reference in New Issue
Block a user