Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-23 21:31:02 +00:00)

Compare commits: isaac/mess ... dev2049/pe (7 commits)

| SHA1 |
|---|
| dfad58c29e |
| 658cee5670 |
| de0bf1a3a4 |
| d7658f6f01 |
| 0c8be8fe23 |
| c7c99d7cfd |
| f06b2ac495 |
@@ -0,0 +1,267 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "68374fbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from langchain.schema import Document\n",
    "\n",
    "\n",
    "# Top 50 Spotify music dataset\n",
    "song_df = pd.read_csv(\"https://gist.githubusercontent.com/rioto9858/ff72b72b3bf5754d29dd1ebf898fc893/raw/1164a139a780b0826faef36c865da65f2d3573e0/top50MusicFrom2010-2019.csv\")\n",
    "docs = [\n",
    "    Document(page_content=t, metadata={\"artist\": a, \"genre\": g, \"year\": y})\n",
    "    for _, (t, a, g, y) in song_df[[\"title\", \"artist\", \"the genre of the track\", \"year\"]].sample(100).iterrows()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcbe04d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pinecone\n",
    "\n",
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.vectorstores import Pinecone\n",
    "\n",
    "\n",
    "# You'll need to sign up for Pinecone and get your credentials to run this demo.\n",
    "pinecone.init(api_key=os.environ[\"PINECONE_API_KEY\"], environment=os.environ[\"PINECONE_ENV\"])\n",
    "pinecone.create_index(\"langchain-self-retriever-demo\", dimension=1536)\n",
    "vectorstore = Pinecone.from_documents(docs, OpenAIEmbeddings(), index_name=\"langchain-self-retriever-demo\")\n",
    "\n",
    "# # If you've already created the index:\n",
    "# vectorstore = Pinecone.from_existing_index(index_name=\"langchain-self-retriever-demo\", embedding=OpenAIEmbeddings())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbd392df",
   "metadata": {},
   "source": [
    "# Self querying"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ab1128d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import OpenAI\n",
    "from langchain.retrievers.pinecone_self_query import (\n",
    "    MetadataFieldInfo,\n",
    "    PineconeSelfQueryRetriever,\n",
    "    VectorStoreExtendedInfo,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "526b0aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorstore_info = VectorStoreExtendedInfo(\n",
    "    vectorstore=vectorstore,\n",
    "    name=\"Top 50 Spotify Songs\",\n",
    "    description=\"The most popular songs on Spotify\",\n",
    "    metadata_field_info=[\n",
    "        MetadataFieldInfo(\n",
    "            name=\"artist\",\n",
    "            description=\"The artist who released the song\",\n",
    "            type=\"string\",\n",
    "            examples=song_df['artist'].sample(3).tolist()\n",
    "        ),\n",
    "        MetadataFieldInfo(\n",
    "            name=\"genre\",\n",
    "            description=\"The genre of the song\",\n",
    "            type=\"string\",\n",
    "            examples=song_df['the genre of the track'].sample(3).tolist()\n",
    "        ),\n",
    "        MetadataFieldInfo(\n",
    "            name=\"year\",\n",
    "            description=\"The year the song was released\",\n",
    "            type=\"integer\",\n",
    "            examples=song_df['year'].sample(3).tolist()\n",
    "        ),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c8e840f0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "retriever = PineconeSelfQueryRetriever.from_vectorstore_info(OpenAI(temperature=0), vectorstore_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "554171d1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'search_string': 'love', 'metadata_filter': {}}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[Document(page_content='Love', metadata={'artist': 'Lana Del Rey', 'genre': 'art pop', 'year': 2017.0}),\n",
       " Document(page_content='L.A.LOVE (la la)', metadata={'artist': 'Fergie', 'genre': 'dance pop', 'year': 2015.0}),\n",
       " Document(page_content='human', metadata={'artist': 'Christina Perri', 'genre': 'dance pop', 'year': 2014.0}),\n",
       " Document(page_content='Someone You Loved', metadata={'artist': 'Lewis Capaldi', 'genre': 'pop', 'year': 2019.0})]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever.get_relevant_documents(\"What are some songs about love\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "96e08d26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'search_string': ' ', 'metadata_filter': {'genre': {'$eq': 'pop'}, 'year': {'$eq': 2012}}}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[Document(page_content='I Knew You Were Trouble.', metadata={'artist': 'Taylor Swift', 'genre': 'pop', 'year': 2012.0}),\n",
       " Document(page_content='One More Night', metadata={'artist': 'Maroon 5', 'genre': 'pop', 'year': 2012.0})]"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever.get_relevant_documents(\"What are some popular pop songs from 2012\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "2651d2f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'search_string': 'money OR success', 'metadata_filter': {'genre': {'$in': ['dance pop', 'art pop']}, 'year': {'$gte': 2015}}}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[Document(page_content='Love', metadata={'artist': 'Lana Del Rey', 'genre': 'art pop', 'year': 2017.0}),\n",
       " Document(page_content='Confident', metadata={'artist': 'Demi Lovato', 'genre': 'dance pop', 'year': 2016.0}),\n",
       " Document(page_content='Up', metadata={'artist': 'Olly Murs', 'genre': 'dance pop', 'year': 2015.0}),\n",
       " Document(page_content='Booty', metadata={'artist': 'Jennifer Lopez', 'genre': 'dance pop', 'year': 2015.0})]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever.get_relevant_documents(\"What are some dance or art pop songs that mention money or success after 2015\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "e0f64adc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'search_string': 'Maroon 5', 'metadata_filter': {'artist': {'$eq': 'Maroon 5'}, 'genre': {'$eq': 'rap'}}}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever.get_relevant_documents(\"Did Maroon 5 release any popular rap songs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d42fc1bd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
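The dict printed above each result set is the document-store query the LLM generated; PineconeSelfQueryRetriever.get_relevant_documents (added later in this diff) prints it before running the search. A minimal sketch of how such a query maps onto the vector store call, reusing the second query from the notebook above and the `vectorstore` object defined there:

    vectorstore_query = {
        "search_string": " ",
        "metadata_filter": {"genre": {"$eq": "pop"}, "year": {"$eq": 2012}},
    }
    # The retriever forwards the two parts to the new VectorStore.search dispatch,
    # passing the metadata filter through to Pinecone.
    docs = vectorstore.search(
        vectorstore_query["search_string"],
        "similarity",
        filter=vectorstore_query["metadata_filter"],
    )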
@@ -22,6 +22,27 @@ def _get_sub_string(schema: ResponseSchema) -> str:
     )
 
 
+def parse_json_markdown(text: str, expected_keys: List[str]) -> Any:
+    if "```json" not in text:
+        raise OutputParserException(
+            f"Got invalid return object. Expected markdown code snippet with JSON "
+            f"object, but got:\n{text}"
+        )
+
+    json_string = text.split("```json")[1].strip().strip("```").strip()
+    try:
+        json_obj = json.loads(json_string)
+    except json.JSONDecodeError as e:
+        raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+    for key in expected_keys:
+        if key not in json_obj:
+            raise OutputParserException(
+                f"Got invalid return object. Expected key `{key}` "
+                f"to be present, but got {json_obj}"
+            )
+    return json_obj
+
+
 class StructuredOutputParser(BaseOutputParser):
     response_schemas: List[ResponseSchema]
 
@@ -38,24 +59,8 @@ class StructuredOutputParser(BaseOutputParser):
         return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
 
     def parse(self, text: str) -> Any:
-        if "```json" not in text:
-            raise OutputParserException(
-                f"Got invalid return object. Expected markdown code snippet with JSON "
-                f"object, but got:\n{text}"
-            )
-
-        json_string = text.split("```json")[1].strip().strip("```").strip()
-        try:
-            json_obj = json.loads(json_string)
-        except json.JSONDecodeError as e:
-            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
-        for schema in self.response_schemas:
-            if schema.name not in json_obj:
-                raise OutputParserException(
-                    f"Got invalid return object. Expected key `{schema.name}` "
-                    f"to be present, but got {json_obj}"
-                )
-        return json_obj
+        expected_keys = [rs.name for rs in self.response_schemas]
+        return parse_json_markdown(text, expected_keys)
 
     @property
     def _type(self) -> str:
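The hunks above factor the fenced-JSON parsing out of StructuredOutputParser.parse into a reusable parse_json_markdown helper. A minimal sketch of calling the helper directly (the response text and keys here are illustrative, not part of the diff):

    from langchain.output_parsers.structured import parse_json_markdown

    llm_response = (
        "```json\n"
        '{"search_string": "love", "metadata_filter": {}}\n'
        "```"
    )
    # Raises OutputParserException if the ```json fence, valid JSON, or an
    # expected key is missing.
    parsed = parse_json_markdown(llm_response, ["search_string", "metadata_filter"])
    # parsed == {"search_string": "love", "metadata_filter": {}}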
langchain/retrievers/pinecone_self_query.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import json
from typing import Any, Dict, List, cast

from pydantic import BaseModel, Field, root_validator

from langchain import LLMChain, PromptTemplate
from langchain.agents.agent_toolkits import VectorStoreInfo
from langchain.llms import BaseLLM
from langchain.retrievers.pinecone_self_query_prompt import (
    PineconeSelfQueryOutputParser,
    pinecone_example,
    pinecone_format_instructions,
    self_query_prompt,
)
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import Pinecone


class MetadataFieldInfo(BaseModel):
    """Information about a vectorstore metadata field."""

    name: str
    description: str
    examples: List
    type: str

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True


class VectorStoreExtendedInfo(VectorStoreInfo):
    """Extension of VectorStoreInfo that includes info about metadata fields."""

    metadata_field_info: List[MetadataFieldInfo]
    """Map of metadata field name to info about that field."""


def _format_metadata_field_info(info: List[MetadataFieldInfo]) -> str:
    info_dicts = {}
    for i in info:
        i_dict = dict(i)
        info_dicts[i_dict.pop("name")] = i_dict
    return json.dumps(info_dicts, indent=2).replace("{", "{{").replace("}", "}}")


class PineconeSelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps around a Pinecone vector store and uses an LLM to generate
    the vector store queries."""

    vectorstore: Pinecone
    """The Pinecone vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in ("similarity", "mmr"):
                raise ValueError(
                    f"search_type of {search_type} not allowed. Expected "
                    "search_type to be 'similarity' or 'mmr'."
                )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        inputs = self.llm_chain.prep_inputs(query)
        vectorstore_query = cast(dict, self.llm_chain.predict_and_parse(**inputs))
        print(vectorstore_query)
        new_query = vectorstore_query["search_string"]
        _filter = vectorstore_query["metadata_filter"]
        docs = self.vectorstore.search(
            new_query, self.search_type, filter=_filter, **self.search_kwargs
        )
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        inputs = self.llm_chain.prep_inputs(query)
        vectorstore_query = cast(
            dict, await self.llm_chain.apredict_and_parse(**inputs)
        )
        new_query = vectorstore_query["search_string"]
        _filter = vectorstore_query["metadata_filter"]
        docs = await self.vectorstore.asearch(
            new_query, self.search_type, filter=_filter, **self.search_kwargs
        )
        return docs

    @classmethod
    def from_vectorstore_info(
        cls,
        llm: BaseLLM,
        vectorstore_info: VectorStoreExtendedInfo,
        **kwargs: Any,
    ) -> "PineconeSelfQueryRetriever":
        metadata_field_json = _format_metadata_field_info(
            vectorstore_info.metadata_field_info
        )
        prompt_str = self_query_prompt.format(
            format_instructions=pinecone_format_instructions,
            example=pinecone_example,
            docstore_description=vectorstore_info.description,
            metadata_fields=metadata_field_json,
        )
        prompt = PromptTemplate(
            input_variables=["question"],
            template=prompt_str,
            output_parser=PineconeSelfQueryOutputParser(),
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore_info.vectorstore,
            **kwargs,
        )
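One detail worth noting in _format_metadata_field_info: the doubled curly braces are there because its output is substituted into self_query_prompt and the result becomes a PromptTemplate, where single braces would be read as template variables. A rough sketch of what it emits, with an illustrative field (the field values are made up for the example):

    from langchain.retrievers.pinecone_self_query import (
        MetadataFieldInfo,
        _format_metadata_field_info,
    )

    field = MetadataFieldInfo(
        name="year",
        description="The year the song was released",
        examples=[2019, 2015],
        type="integer",
    )
    print(_format_metadata_field_info([field]))
    # {{
    #   "year": {{
    #     "description": "The year the song was released",
    #     "examples": [
    #       2019,
    #       2015
    #     ],
    #     "type": "integer"
    #   }}
    # }}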
langchain/retrievers/pinecone_self_query_prompt.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# flake8: noqa
from typing import Dict

from langchain.output_parsers.structured import parse_json_markdown
from langchain.schema import BaseOutputParser

pinecone_format_instructions = """RESPONSE FORMAT
----------------------------

When responding use a markdown code snippet formatted in the following schema:

```json
{{
    "search_string": string \\ The string to compare to document contents,
    "metadata_filter": {{
        "<MetadataField>": {{
            <Operator>: <Value>
        }},
        "<MetadataField>": {{
            <Operator>: <Value>
        }},
    }} \\ The conditions on which to filter the metadata
}}
```
Filtering metadata supports the following operators:
$eq - Equal to (number, string, boolean)
$ne - Not equal to (number, string, boolean)
$gt - Greater than (number)
$gte - Greater than or equal to (number)
$lt - Less than (number)
$lte - Less than or equal to (number)
$in - In array (string or number)
$nin - Not in array (string or number)

NOTE that if you are not exactly sure how some string metadata values are formatted you can include multiple potentially matching values. For example if you're not sure how a string value is capitalized, you can check for equality with both the fully upper-cased or fully lower-cased versions of the value.

PLEASE REMEMBER that for some queries no metadata filters are needed. In these cases you should leave "metadata_filter" as an empty map."""
pinecone_example = """EXAMPLE
----------------------------

DOCUMENT STORE DESCRIPTION: News headlines from around the world
METADATA FIELDS: {{
    "year": {{
        "description": "The year the headline was published",
        "type": "integer",
        "example_values": [2022, 1997]
    }},
    "country": {{
        "description": "The country of origin of the news media outlet",
        "type": "string",
        "example_values": ["Chile", "Japan", "Ghana"]
    }},
    "source": {{
        "description": "The name of the news media outlet",
        "type": "string",
        "example_values": ["Wall Street Journal", "New York Times", "Axios"]
    }}
}}
QUESTION: What was the sentiment of Mexican and Taiwanese media outlets regarding the 2024 trade deal between America and India?
DOCUMENT STORE QUERY:
```json
{{
    "search_string": "Trade deal between America and India",
    "metadata_filter": {{
        "country": {{
            "$in": ["Mexico", "United Mexican States", "mexico", "Taiwan", "Republic of China", "ROC", "taiwan"]
        }},
        "year": {{
            "$gte": 2024
        }}
    }}
}}
```"""
self_query_prompt = """INSTRUCTIONS
----------------------------

You have access to a store of documents. Each document contains text and a key-value store of associated metadata. Given a user question, your job is to come up with a fully formed query to the document store that will return the most relevant documents.

A document store query consists of two components: a search string and a metadata filter. The search string is compared to the text contents of the stored documents. The metadata filter is used to filter out documents whose metadata does not match the given criteria.


{format_instructions}


{example}

Begin!

DOCUMENT STORE DESCRIPTION: {docstore_description}
METADATA FIELDS: {metadata_fields}
QUESTION: {{question}}
DOCUMENT STORE QUERY:"""


class PineconeSelfQueryOutputParser(BaseOutputParser[Dict]):
    def get_format_instructions(self) -> str:
        return pinecone_format_instructions

    def parse(self, text: str) -> Dict:
        expected_keys = ["search_string", "metadata_filter"]
        parsed = parse_json_markdown(text, expected_keys)
        if len(parsed["search_string"]) == 0:
            parsed["search_string"] = " "
        return parsed

    @property
    def _type(self) -> str:
        return "pinecone_self_query_output_parser"
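A minimal sketch of the output parser at work on a completion that follows the RESPONSE FORMAT above (the completion text is illustrative, not from the diff):

    from langchain.retrievers.pinecone_self_query_prompt import PineconeSelfQueryOutputParser

    parser = PineconeSelfQueryOutputParser()
    completion = (
        "```json\n"
        '{"search_string": "", "metadata_filter": {"year": {"$gte": 2015}}}\n'
        "```"
    )
    query = parser.parse(completion)
    # An empty search string is replaced with a single space so the downstream
    # similarity search still receives a non-empty query:
    # query == {"search_string": " ", "metadata_filter": {"year": {"$gte": 2015}}}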
@@ -75,6 +75,32 @@ class VectorStore(ABC):
         metadatas = [doc.metadata for doc in documents]
         return await self.aadd_texts(texts, metadatas, **kwargs)
 
+    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
+        """Return docs most similar to query using specified search type."""
+        if search_type == "similarity":
+            return self.similarity_search(query, **kwargs)
+        elif search_type == "mmr":
+            return self.max_marginal_relevance_search(query, **kwargs)
+        else:
+            raise ValueError(
+                f"search_type of {search_type} not allowed. Expected "
+                "search_type to be 'similarity' or 'mmr'."
+            )
+
+    async def asearch(
+        self, query: str, search_type: str, **kwargs: Any
+    ) -> List[Document]:
+        """Return docs most similar to query using specified search type."""
+        if search_type == "similarity":
+            return await self.asimilarity_search(query, **kwargs)
+        elif search_type == "mmr":
+            return await self.amax_marginal_relevance_search(query, **kwargs)
+        else:
+            raise ValueError(
+                f"search_type of {search_type} not allowed. Expected "
+                "search_type to be 'similarity' or 'mmr'."
+            )
+
     @abstractmethod
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
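This is the dispatch the new retriever relies on: it calls vectorstore.search(new_query, self.search_type, filter=...) and lets the base class route to similarity or MMR search. A minimal sketch of calling it directly; the store and filter are illustrative, and extra kwargs are simply forwarded, so they must be supported by the underlying store:

    # `vectorstore` is any VectorStore subclass, e.g. the Pinecone store from the
    # notebook above.
    docs = vectorstore.search(
        "songs about love",
        "similarity",
        k=4,
        filter={"genre": {"$eq": "pop"}},
    )
    # "mmr" routes to max_marginal_relevance_search instead;
    # any other search_type raises ValueError.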
@@ -243,7 +243,7 @@ class DeepLake(VectorStore):
         self.ds.summary()
         return ids
 
-    def search(
+    def _search_helper(
         self,
         query: Any[str, None] = None,
         embedding: Any[float, None] = None,
@@ -366,7 +366,7 @@ class DeepLake(VectorStore):
         Returns:
             List of Documents most similar to the query vector.
         """
-        return self.search(query=query, k=k, **kwargs)
+        return self._search_helper(query=query, k=k, **kwargs)
 
     def similarity_search_by_vector(
         self, embedding: List[float], k: int = 4, **kwargs: Any
@@ -379,7 +379,7 @@ class DeepLake(VectorStore):
         Returns:
             List of Documents most similar to the query vector.
         """
-        return self.search(embedding=embedding, k=k, **kwargs)
+        return self._search_helper(embedding=embedding, k=k, **kwargs)
 
     def similarity_search_with_score(
         self,
@@ -401,7 +401,7 @@ class DeepLake(VectorStore):
             List[Tuple[Document, float]]: List of documents most similar to the query
                 text with distance in float.
         """
-        return self.search(
+        return self._search_helper(
             query=query,
             k=k,
             filter=filter,
@@ -431,7 +431,7 @@ class DeepLake(VectorStore):
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
-        return self.search(
+        return self._search_helper(
             embedding=embedding,
             k=k,
             fetch_k=fetch_k,
@@ -465,7 +465,7 @@ class DeepLake(VectorStore):
             raise ValueError(
                 "For MMR search, you must specify an embedding function on" "creation."
             )
-        return self.search(
+        return self._search_helper(
             query=query,
             k=k,
             fetch_k=fetch_k,