Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
eb1cd3d66c eval 2023-03-09 16:44:30 -08:00
6 changed files with 532 additions and 1 deletions

View File

@@ -0,0 +1,205 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "63a6161b",
"metadata": {},
"outputs": [],
"source": [
"from langchain import OpenAI, SQLDatabase, SQLDatabaseChain"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "24f017da",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3e980729",
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "f8b4e54f",
"metadata": {},
"outputs": [],
"source": [
"questions = [\n",
" {\n",
" \"question\": \"How many employees are there?\",\n",
" \"answer\": \"8\"\n",
" },\n",
" {\n",
" \"question\": \"What are some example tracks by composer Johann Sebastian Bach?\",\n",
" \"answer\": \"'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 'Goldberg Variations': Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'\"\n",
" },\n",
" {\n",
" \"question\": \"What are some example tracks by Bach?\",\n",
" \"answer\": \"'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 'Goldberg Variations': Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'\"\n",
" },\n",
" {\n",
" \"question\": \"How many employees are also customers?\",\n",
" \"answer\": \"None\"\n",
" },\n",
" {\n",
" \"question\": \"Where is Mark Telus from?\",\n",
" \"answer\": \"Edmonton, Canada\"\n",
" },\n",
" {\n",
" \"question\": \"What is the most common genre of songs?\",\n",
" \"answer\": \"Rock\"\n",
" },\n",
" {\n",
" \"question\": \"What is the most common media type?\",\n",
" \"answer\": \"MPEG audio file\"\n",
" },\n",
" {\n",
" \"question\": \"What is the most common media type?\",\n",
" \"answer\": \"Purchased AAC audio file\"\n",
" },\n",
" {\n",
" \"question\": \"How many more Protected AAC audio files are there than Protected MPEG-4 video file?\",\n",
" \"answer\": \"23\"\n",
" },\n",
" {\n",
" \"question\": \"How many albums are there\",\n",
" \"answer\": \"347\"\n",
" }\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "5896eda7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(questions)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "21dc41ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(1, 3034), (2, 237), (3, 214), (4, 7), (5, 11)]'"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(\"\"\"SELECT\n",
" MediaTypeID,\n",
" COUNT(*) AS `num`\n",
"FROM\n",
" Track\n",
"GROUP BY\n",
" MediaTypeID\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "115cd3da",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"''"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.get_table_info()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "659c8d20",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(347,)]'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(\"select count(*) from album\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b99a505",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,236 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"id": "d2ff9d0e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"sota_loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"pg_loader = TextLoader(\"../../../../gpt_index/examples/paul_graham_essay/data/paul_graham_essay.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "57794791",
"metadata": {},
"outputs": [],
"source": [
"from langchain.indexes import VectorstoreIndexCreator\n",
"from langchain.vectorstores import FAISS"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "5fa10ffb",
"metadata": {},
"outputs": [],
"source": [
"sota_index = VectorstoreIndexCreator(vectorstore_cls=FAISS).from_loaders([sota_loader])\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "34ceb9c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"pg_index = VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\": \"paul-graham\"}).from_loaders([pg_loader])\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "9c3d06e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" The President nominated Circuit Court of Appeals Judge Ketanji Brown Jackson to serve on the United States Supreme Court. He said she is one of the nation's top legal minds and will continue Justice Breyer's legacy of excellence.\""
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sota_index.query(\"what did the president about kentaji brown jackson?\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "94be0f0f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" Kentaji Brown Jackson was not mentioned in the context, so I don't know.\""
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pg_index.query(\"what did the president about kentaji brown jackson?\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "fb7a1185",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.tools import BaseTool\n",
"from langchain.llms import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b853ce83",
"metadata": {},
"outputs": [],
"source": [
"tools = [\n",
" Tool(\n",
" name = \"State of Union QA System\",\n",
" func=sota_index.query,\n",
" description=\"useful for when you need to answer questions about the most recent state of the union address. Input should be a fully formed question.\"\n",
" ),\n",
" Tool(\n",
" name = \"Paul Graham QA System\",\n",
" func=pg_index.query,\n",
" description=\"useful for when you need to answer questions about Paul Graham. Input should be a fully formed question.\"\n",
" ),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a959727e",
"metadata": {},
"outputs": [],
"source": [
"agent = initialize_agent(tools, OpenAI(temperature=0), agent=\"zero-shot-react-description\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "20440754",
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "59d1547f",
"metadata": {},
"outputs": [],
"source": [
"with open(\"../../../notebooks/state_of_union_qa.json\") as f:\n",
" sota_qa = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c3c457df",
"metadata": {},
"outputs": [],
"source": [
"with open(\"../../../notebooks/paul_graham_qa.json\") as f:\n",
" pg_qa = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "36e1ddc2",
"metadata": {},
"outputs": [],
"source": [
"for d in sota_qa:\n",
" d['steps'] = [{\"tool\": \"State of Union QA System\"}, {\"tool_input\": d[\"question\"]}]\n",
"for d in pg_qa:\n",
" d['steps'] = [{\"tool\": \"Paul Graham QA System\"}, {\"tool_input\": d[\"question\"]}]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "59069433",
"metadata": {},
"outputs": [],
"source": [
"all_vectorstore_routing = sota_qa + pg_qa"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "157a27bb",
"metadata": {},
"outputs": [],
"source": [
"with open(\"vectorstore_sota_pg.json\", \"w\") as f:\n",
" json.dump(all_vectorstore_routing, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe86c9d2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,48 @@
from typing import Dict, List, Optional, Any
from langchain.chains.base import Chain
from pydantic import Field
import json
from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.schema import BaseLanguageModel
from langchain.prompts.base import BasePromptTemplate
from langchain.chains.llm import LLMChain
class QAGenerationChain(Chain):
llm_chain: LLMChain
text_splitter: TextSplitter = Field(default=RecursiveCharacterTextSplitter(chunk_overlap=500))
input_key: str = "text"
output_key: str = "questions"
k: Optional[int] = None
@classmethod
def from_llm(cls, llm: BaseLanguageModel, prompt: Optional[BasePromptTemplate] = None, **kwargs: Any):
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
chain = LLMChain(llm=llm, prompt=_prompt)
return cls(llm_chain=chain, **kwargs)
@property
def _chain_type(self) -> str:
raise NotImplementedError
@property
def input_keys(self) -> List[str]:
return[ self.input_key]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
docs = self.text_splitter.create_documents([inputs[self.input_key]])
results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
qa = [json.loads(res[0].text) for res in results.generations]
return {self.output_key: qa}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
raise NotImplementedError

View File

@@ -0,0 +1,41 @@
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.prompt import PromptTemplate
templ1 = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
When coming up with this question/answer pair, you must respond in the following format:
```
{{
"question": "$YOUR_QUESTION_HERE",
"answer": "$THE_ANSWER_HERE"
}}
```
Everything between the ``` must be valid json.
"""
templ2 = """Please come up with a question/answer pair, in the specified JSON format, for the following text:
----------------
{text}"""
CHAT_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(templ1),
HumanMessagePromptTemplate.from_template(templ2)
])
templ = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
When coming up with this question/answer pair, you must respond in the following format:
```
{{
"question": "$YOUR_QUESTION_HERE",
"answer": "$THE_ANSWER_HERE"
}}
```
Everything between the ``` must be valid json.
Please come up with a question/answer pair, in the specified JSON format, for the following text:
----------------
{text}"""
PROMPT = PromptTemplate.from_template(templ)
PROMPT_SELECTOR = ConditionalPromptSelector(default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)])

View File

@@ -52,6 +52,7 @@ class VectorstoreIndexCreator(BaseModel):
vectorstore_cls: Type[VectorStore] = Chroma
embedding: Embeddings = Field(default_factory=OpenAIEmbeddings)
text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter)
vectorstore_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
@@ -65,5 +66,5 @@ class VectorstoreIndexCreator(BaseModel):
for loader in loaders:
docs.extend(loader.load())
sub_docs = self.text_splitter.split_documents(docs)
vectorstore = self.vectorstore_cls.from_documents(sub_docs, self.embedding)
vectorstore = self.vectorstore_cls.from_documents(sub_docs, self.embedding, **self.vectorstore_kwargs)
return VectorStoreIndexWrapper(vectorstore=vectorstore)