Compare commits

...

7 Commits

Author SHA1 Message Date
Dev 2049
dfad58c29e cr 2023-04-24 17:43:27 -07:00
Dev 2049
658cee5670 wip 2023-04-24 16:40:27 -07:00
Dev 2049
de0bf1a3a4 Merge branch 'master' into dev2049/perfect_retriever 2023-04-24 13:36:10 -07:00
Dev 2049
d7658f6f01 wip 2023-04-14 16:29:46 -07:00
Dev 2049
0c8be8fe23 Merge branch 'master' into dev2049/perfect_retriever 2023-04-14 15:04:17 -07:00
Dev 2049
c7c99d7cfd format 2023-04-13 19:52:57 -07:00
Dev 2049
f06b2ac495 rfc 2023-04-13 19:52:11 -07:00
6 changed files with 563 additions and 24 deletions

View File

@@ -0,0 +1,267 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 24,
"id": "68374fbd",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from langchain.schema import Document\n",
"\n",
"\n",
"# Top 50 Spotify music dataset\n",
"song_df = pd.read_csv(\"https://gist.githubusercontent.com/rioto9858/ff72b72b3bf5754d29dd1ebf898fc893/raw/1164a139a780b0826faef36c865da65f2d3573e0/top50MusicFrom2010-2019.csv\")\n",
"docs = [\n",
" Document(page_content=t, metadata={\"artist\": a, \"genre\": g, \"year\": y})\n",
" for _, (t, a, g, y) in song_df[[\"title\", \"artist\", \"the genre of the track\", \"year\"]].sample(100).iterrows()\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcbe04d9",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import pinecone\n",
"\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores import Pinecone\n",
"\n",
"\n",
"# You'll need to sign up for Pinecone and get your credentials to run this demo.\n",
"pinecone.init(api_key=os.environ[\"PINECONE_API_KEY\"], environment=os.environ[\"PINECONE_ENV\"])\n",
"pinecone.create_index(\"langchain-self-retriever-demo\", dimension=1536)\n",
"vectorstore = Pinecone.from_documents(docs, OpenAIEmbeddings(), index_name=\"langchain-self-retriever-demo\")\n",
"\n",
"# # If you've already created the index:\n",
"# vectorstore = Pinecone.from_existing_index(index_name=\"langchain-self-retriever-demo\", embedding=OpenAIEmbeddings())"
]
},
{
"cell_type": "markdown",
"id": "fbd392df",
"metadata": {},
"source": [
"# Self querying"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ab1128d7",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.retrievers.pinecone_self_query import (\n",
" MetadataFieldInfo,\n",
" PineconeSelfQueryRetriever, \n",
" VectorStoreExtendedInfo, \n",
")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "526b0aa2",
"metadata": {},
"outputs": [],
"source": [
"vectorstore_info = VectorStoreExtendedInfo(\n",
" vectorstore=vectorstore, \n",
" name=\"Top 50 Spotify Songs\", \n",
" description=\"The most popular songs on Spotify\", \n",
" metadata_field_info=[\n",
" MetadataFieldInfo(\n",
" name=\"artist\",\n",
" description=\"The artist who released the song\", \n",
" type=\"string\", \n",
" examples=song_df['artist'].sample(3).tolist()\n",
" ),\n",
" MetadataFieldInfo(\n",
" name=\"genre\",\n",
" description=\"The genre of the song\", \n",
" type=\"string\", \n",
" examples=song_df['the genre of the track'].sample(3).tolist()\n",
" ),\n",
" MetadataFieldInfo(\n",
" name=\"year\",\n",
" description=\"The year the song was released\", \n",
" type=\"integer\", \n",
" examples=song_df['year'].sample(3).tolist()\n",
" ),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "c8e840f0",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"retriever = PineconeSelfQueryRetriever.from_vectorstore_info(OpenAI(temperature=0), vectorstore_info)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "554171d1",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'search_string': 'love', 'metadata_filter': {}}\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='Love', metadata={'artist': 'Lana Del Rey', 'genre': 'art pop', 'year': 2017.0}),\n",
" Document(page_content='L.A.LOVE (la la)', metadata={'artist': 'Fergie', 'genre': 'dance pop', 'year': 2015.0}),\n",
" Document(page_content='human', metadata={'artist': 'Christina Perri', 'genre': 'dance pop', 'year': 2014.0}),\n",
" Document(page_content='Someone You Loved', metadata={'artist': 'Lewis Capaldi', 'genre': 'pop', 'year': 2019.0})]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"What are some songs about love\")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "96e08d26",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'search_string': ' ', 'metadata_filter': {'genre': {'$eq': 'pop'}, 'year': {'$eq': 2012}}}\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='I Knew You Were Trouble.', metadata={'artist': 'Taylor Swift', 'genre': 'pop', 'year': 2012.0}),\n",
" Document(page_content='One More Night', metadata={'artist': 'Maroon 5', 'genre': 'pop', 'year': 2012.0})]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"What are some popular pop songs from 2012\")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "2651d2f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'search_string': 'money OR success', 'metadata_filter': {'genre': {'$in': ['dance pop', 'art pop']}, 'year': {'$gte': 2015}}}\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='Love', metadata={'artist': 'Lana Del Rey', 'genre': 'art pop', 'year': 2017.0}),\n",
" Document(page_content='Confident', metadata={'artist': 'Demi Lovato', 'genre': 'dance pop', 'year': 2016.0}),\n",
" Document(page_content='Up', metadata={'artist': 'Olly Murs', 'genre': 'dance pop', 'year': 2015.0}),\n",
" Document(page_content='Booty', metadata={'artist': 'Jennifer Lopez', 'genre': 'dance pop', 'year': 2015.0})]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"What are some dance or art pop songs that mention money or success after 2015\")"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "e0f64adc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'search_string': 'Maroon 5', 'metadata_filter': {'artist': {'$eq': 'Maroon 5'}, 'genre': {'$eq': 'rap'}}}\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"Did Maroon 5 release any popular rap songs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d42fc1bd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -22,6 +22,27 @@ def _get_sub_string(schema: ResponseSchema) -> str:
)
def parse_json_markdown(text: str, expected_keys: List[str]) -> Any:
    """Extract and validate the JSON object inside a ```json markdown snippet.

    Args:
        text: Raw text (typically LLM output) expected to contain a markdown
            code block opened with ```json.
        expected_keys: Keys that must be present in the parsed JSON object.

    Returns:
        The parsed JSON value.

    Raises:
        OutputParserException: If no ```json block is present, the block is
            not valid JSON, or one of ``expected_keys`` is missing.
    """
    if "```json" not in text:
        raise OutputParserException(
            f"Got invalid return object. Expected markdown code snippet with JSON "
            f"object, but got:\n{text}"
        )
    remainder = text.split("```json")[1]
    # str.strip takes a *set* of characters, so "`" is equivalent to the
    # previous "```" argument: it trims any run of backticks at either end.
    json_string = remainder.strip().strip("`").strip()
    try:
        json_obj = json.loads(json_string)
    except json.JSONDecodeError as e:
        # The model may have appended commentary after the closing fence.
        # Retry with only the content that precedes the first closing fence;
        # inputs that parsed before this fallback existed are unaffected.
        fenced_only = remainder.partition("```")[0].strip()
        try:
            json_obj = json.loads(fenced_only)
        except json.JSONDecodeError:
            raise OutputParserException(
                f"Got invalid JSON object. Error: {e}"
            ) from e
    for key in expected_keys:
        # A non-dict payload can never satisfy a key requirement. Checking the
        # type first avoids a TypeError on numbers and a bogus substring match
        # on JSON strings.
        if not isinstance(json_obj, dict) or key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
class StructuredOutputParser(BaseOutputParser):
response_schemas: List[ResponseSchema]
@@ -38,24 +59,8 @@ class StructuredOutputParser(BaseOutputParser):
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> Any:
    """Parse a ```json markdown snippet and check it against the response schemas.

    Delegates to ``parse_json_markdown`` so the extraction and key-validation
    logic lives in one place instead of being duplicated inline.

    Args:
        text: Raw text expected to contain a ```json code block.

    Returns:
        The parsed JSON object.

    Raises:
        OutputParserException: Propagated from ``parse_json_markdown`` when the
            snippet is missing, is not valid JSON, or lacks a schema's key.
    """
    # Every declared response schema contributes one required top-level key.
    expected_keys = [rs.name for rs in self.response_schemas]
    return parse_json_markdown(text, expected_keys)
@property
def _type(self) -> str:

View File

@@ -0,0 +1,133 @@
import json
from typing import Any, Dict, List, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain, PromptTemplate
from langchain.agents.agent_toolkits import VectorStoreInfo
from langchain.llms import BaseLLM
from langchain.retrievers.pinecone_self_query_prompt import (
PineconeSelfQueryOutputParser,
pinecone_example,
pinecone_format_instructions,
self_query_prompt,
)
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import Pinecone
class MetadataFieldInfo(BaseModel):
    """Describes a single metadata field stored alongside vectorstore documents."""

    # Field name; used as the key when the field list is rendered to JSON.
    name: str
    # Human-readable explanation of what the field holds.
    description: str
    # Sample values for the field (element type is left unconstrained).
    examples: List
    # Free-form type label for the field, e.g. "string" or "integer".
    type: str

    class Config:
        """Pydantic configuration for this model."""

        arbitrary_types_allowed = True
class VectorStoreExtendedInfo(VectorStoreInfo):
    """Extension of VectorStoreInfo that includes info about metadata fields."""

    metadata_field_info: List[MetadataFieldInfo]
    """Descriptions of the metadata fields, one ``MetadataFieldInfo`` per field."""
def _format_metadata_field_info(info: List[MetadataFieldInfo]) -> str:
info_dicts = {}
for i in info:
i_dict = dict(i)
info_dicts[i_dict.pop("name")] = i_dict
return json.dumps(info_dicts, indent=2).replace("{", "{{").replace("}", "}}")
class PineconeSelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps around a Pinecone vector store and uses an LLM to generate
    the vector store queries."""

    vectorstore: Pinecone
    """The Pinecone vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type.

        Raises:
            ValueError: If ``search_type`` is neither 'similarity' nor 'mmr'.
        """
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in ("similarity", "mmr"):
                raise ValueError(
                    f"search_type of {search_type} not allowed. Expected "
                    "search_type to be 'similarity' or 'mmr'."
                )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        The LLM chain rewrites ``query`` into a search string plus a metadata
        filter, which are then passed to the vector store search.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        inputs = self.llm_chain.prep_inputs(query)
        vectorstore_query = cast(dict, self.llm_chain.predict_and_parse(**inputs))
        # Echo the generated query; the demo notebook's recorded output shows it.
        print(vectorstore_query)
        new_query = vectorstore_query["search_string"]
        _filter = vectorstore_query["metadata_filter"]
        docs = self.vectorstore.search(
            new_query, self.search_type, filter=_filter, **self.search_kwargs
        )
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronously get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        inputs = self.llm_chain.prep_inputs(query)
        # apredict_and_parse is a coroutine: it must be awaited before the
        # parsed result can be indexed.
        vectorstore_query = cast(
            dict, await self.llm_chain.apredict_and_parse(**inputs)
        )
        # Same keys as the sync path; these are the keys the output parser
        # validates and returns.
        new_query = vectorstore_query["search_string"]
        _filter = vectorstore_query["metadata_filter"]
        docs = await self.vectorstore.asearch(
            new_query, self.search_type, filter=_filter, **self.search_kwargs
        )
        return docs

    @classmethod
    def from_vectorstore_info(
        cls,
        llm: BaseLLM,
        vectorstore_info: VectorStoreExtendedInfo,
        **kwargs: Any,
    ) -> "PineconeSelfQueryRetriever":
        """Build a retriever from a vector store description.

        Args:
            llm: Language model used to generate the vector store queries.
            vectorstore_info: Vector store plus its description and metadata
                field descriptions.
            **kwargs: Extra fields forwarded to the retriever constructor
                (e.g. ``search_type``, ``search_kwargs``).

        Returns:
            A retriever whose LLM chain is prompted with the store description,
            the metadata fields, the format instructions and a worked example.
        """
        metadata_field_json = _format_metadata_field_info(
            vectorstore_info.metadata_field_info
        )
        # First formatting pass fills in everything except the question; the
        # template's ``{{question}}`` becomes a ``{question}`` placeholder that
        # the PromptTemplate below declares as its only input variable.
        prompt_str = self_query_prompt.format(
            format_instructions=pinecone_format_instructions,
            example=pinecone_example,
            docstore_description=vectorstore_info.description,
            metadata_fields=metadata_field_json,
        )
        prompt = PromptTemplate(
            input_variables=["question"],
            template=prompt_str,
            output_parser=PineconeSelfQueryOutputParser(),
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore_info.vectorstore,
            **kwargs,
        )

View File

@@ -0,0 +1,108 @@
# flake8: noqa
from typing import Dict
from langchain.output_parsers.structured import parse_json_markdown
from langchain.schema import BaseOutputParser
# Response-format section of the self-query prompt. It asks the model to reply
# with a ```json snippet holding "search_string" and "metadata_filter", and
# lists the filter operators the model may use.
# Braces are doubled throughout. NOTE(review): presumably so they collapse to
# single braces when the final prompt template is rendered - confirm.
# NOTE(review): in this non-raw string each `\\` renders as a single backslash
# before the inline field descriptions; `//` may have been intended - confirm.
pinecone_format_instructions = """RESPONSE FORMAT
----------------------------
When responding use a markdown code snippet formatted in the following schema:
```json
{{
"search_string": string \\ The string to compare to document contents,
"metadata_filter": {{
"<MetadataField>": {{
<Operator>: <Value>
}},
"<MetadataField>": {{
<Operator>: <Value>
}},
}} \\ The conditions on which to filter the metadata
}}
```
Filtering metadata supports the following operators:
$eq - Equal to (number, string, boolean)
$ne - Not equal to (number, string, boolean)
$gt - Greater than (number)
$gte - Greater than or equal to (number)
$lt - Less than (number)
$lte - Less than or equal to (number)
$in - In array (string or number)
$nin - Not in array (string or number)
NOTE that if you are not exactly sure how some string metadata values are formatted you can include multiple potentially matching values. For example if you're not sure how a string value is capitalized, you can check for equality with both the fully upper-cased or fully lower-cased versions of the value.
PLEASE REMEMBER that for some queries no metadata filters are needed. In these cases you should leave "metadata_filter" as an empty map."""
# Worked example embedded in the self-query prompt: a news-headline store with
# three metadata fields, a sample question, and the ```json query the model is
# expected to produce (a search string plus "$in" / "$gte" metadata filters).
# Braces are doubled throughout, matching the other prompt fragments.
# NOTE(review): the example labels sample values "example_values", while the
# MetadataFieldInfo model names that field "examples" - confirm which key the
# rendered METADATA FIELDS section should use.
pinecone_example = """EXAMPLE
----------------------------
DOCUMENT STORE DESCRIPTION: News headlines from around the world
METADATA FIELDS: {{
"year": {{
"description": "The year the headline was published",
"type": "integer",
"example_values": [2022, 1997]
}},
"country": {{
"description": "The country of origin of the news media outlet",
"type": "string",
"example_values": ["Chile", "Japan", "Ghana"]
}},
"source": {{
"description": "The name of the news media outlet",
"type": "string",
"example_values": ["Wall Street Journal", "New York Times", "Axios"]
}}
}}
QUESTION: What was the sentiment of Mexican and Taiwanese media outlets regarding the 2024 trade deal between America and India?
DOCUMENT STORE QUERY:
```json
{{
"search_string": "Trade deal between America and India",
"metadata_filter": {{
"country": {{
"$in": ["Mexico", "United Mexican States", "mexico", "Taiwan", "Republic of China", "ROC", "taiwan"]
}},
"year": {{
"$gte": 2024
}}
}}
}}
```"""
# Top-level self-query prompt template.
# Single-brace placeholders ({format_instructions}, {example},
# {docstore_description}, {metadata_fields}) are filled by a str.format call.
# {{question}} is double-braced, so that same call leaves a literal {question}
# placeholder behind for a later rendering step to fill.
self_query_prompt = """INSTRUCTIONS
----------------------------
You have access to a store of documents. Each document contains text and a key-value store of associated metadata. Given a user question, your job is to come up with a fully formed query to the document store that will return the most relevant documents.
A document store query consists of two components: a search string and a metadata filter. The search string is compared to the text contents of the stored documents. The metadata filter is used to filter out documents whose metadata does not match the given criteria.
{format_instructions}
{example}
Begin!
DOCUMENT STORE DESCRIPTION: {docstore_description}
METADATA FIELDS: {metadata_fields}
QUESTION: {{question}}
DOCUMENT STORE QUERY:"""
class PineconeSelfQueryOutputParser(BaseOutputParser[Dict]):
    """Turn the model's ```json reply into a search-string / metadata-filter dict."""

    def get_format_instructions(self) -> str:
        """Return the response-format section used in the self-query prompt."""
        return pinecone_format_instructions

    def parse(self, text: str) -> Dict:
        """Parse ``text`` and require both query components to be present.

        An empty ``search_string`` is replaced with a single space.
        NOTE(review): presumably because the downstream search cannot handle
        an empty string - confirm.
        """
        query = parse_json_markdown(text, ["search_string", "metadata_filter"])
        if not len(query["search_string"]):
            query["search_string"] = " "
        return query

    @property
    def _type(self) -> str:
        """Identifier string for this parser."""
        return "pinecone_self_query_output_parser"

View File

@@ -75,6 +75,32 @@ class VectorStore(ABC):
metadatas = [doc.metadata for doc in documents]
return await self.aadd_texts(texts, metadatas, **kwargs)
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
    """Dispatch ``query`` to the search method selected by ``search_type``.

    Args:
        query: Text to search for.
        search_type: Either ``"similarity"`` or ``"mmr"``.
        **kwargs: Forwarded unchanged to the selected search method.

    Returns:
        Whatever the selected search method returns.

    Raises:
        ValueError: If ``search_type`` is neither ``"similarity"`` nor ``"mmr"``.
    """
    if search_type == "similarity":
        return self.similarity_search(query, **kwargs)
    if search_type == "mmr":
        return self.max_marginal_relevance_search(query, **kwargs)
    message = (
        f"search_type of {search_type} not allowed. Expected "
        "search_type to be 'similarity' or 'mmr'."
    )
    raise ValueError(message)
async def asearch(
    self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
    """Asynchronously dispatch ``query`` to the search selected by ``search_type``.

    Args:
        query: Text to search for.
        search_type: Either ``"similarity"`` or ``"mmr"``.
        **kwargs: Forwarded unchanged to the selected async search method.

    Returns:
        Whatever the awaited search method returns.

    Raises:
        ValueError: If ``search_type`` is neither ``"similarity"`` nor ``"mmr"``.
    """
    if search_type == "similarity":
        return await self.asimilarity_search(query, **kwargs)
    if search_type == "mmr":
        return await self.amax_marginal_relevance_search(query, **kwargs)
    message = (
        f"search_type of {search_type} not allowed. Expected "
        "search_type to be 'similarity' or 'mmr'."
    )
    raise ValueError(message)
@abstractmethod
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any

View File

@@ -243,7 +243,7 @@ class DeepLake(VectorStore):
self.ds.summary()
return ids
def search(
def _search_helper(
self,
query: Any[str, None] = None,
embedding: Any[float, None] = None,
@@ -366,7 +366,7 @@ class DeepLake(VectorStore):
Returns:
List of Documents most similar to the query vector.
"""
return self.search(query=query, k=k, **kwargs)
return self._search_helper(query=query, k=k, **kwargs)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
@@ -379,7 +379,7 @@ class DeepLake(VectorStore):
Returns:
List of Documents most similar to the query vector.
"""
return self.search(embedding=embedding, k=k, **kwargs)
return self._search_helper(embedding=embedding, k=k, **kwargs)
def similarity_search_with_score(
self,
@@ -401,7 +401,7 @@ class DeepLake(VectorStore):
List[Tuple[Document, float]]: List of documents most similar to the query
text with distance in float.
"""
return self.search(
return self._search_helper(
query=query,
k=k,
filter=filter,
@@ -431,7 +431,7 @@ class DeepLake(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
return self.search(
return self._search_helper(
embedding=embedding,
k=k,
fetch_k=fetch_k,
@@ -465,7 +465,7 @@ class DeepLake(VectorStore):
raise ValueError(
"For MMR search, you must specify an embedding function on" "creation."
)
return self.search(
return self._search_helper(
query=query,
k=k,
fetch_k=fetch_k,