Compare commits
14 Commits
v0.0.273...bagatur/gp

| SHA1 |
|---|
| ed8753e7ce |
| 0d01cede03 |
| 63921e327d |
| aab01b55db |
| 0da5803f5a |
| a28eea5767 |
| fa0b8f3368 |
| 12a373810c |
| d57d08fd01 |
| 4339d21cf1 |
| 1960ac8d25 |
| 2ab04a4e32 |
| 985873c497 |
| 709a67d9bf |
1396
docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb
Normal file
@@ -2,5 +2,5 @@
One of the key concerns with using LLMs is that they may generate harmful or unethical text. This is an area of active research in the field. Here we present some built-in chains inspired by this research, which are intended to make the outputs of LLMs safer.

- [Moderation chain](/docs/use_cases/safety/moderation): Explicitly check if any output text is harmful and flag it.
- [Constitutional chain](/docs/use_cases/safety/constitutional_chain): Prompt the model with a set of principles which should guide its behavior.
- [Moderation chain](/docs/guides/safety/moderation): Explicitly check if any output text is harmful and flag it.
- [Constitutional chain](/docs/guides/safety/constitutional_chain): Prompt the model with a set of principles which should guide its behavior.
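For reference, here is a minimal sketch of the Moderation chain described above, assuming the `OpenAIModerationChain` class from `langchain.chains` and an `OPENAI_API_KEY` set in the environment:

```python
# Minimal sketch (not part of this diff): flag harmful text with the built-in moderation chain.
# Assumes `pip install langchain openai` and OPENAI_API_KEY in the environment.
from langchain.chains import OpenAIModerationChain

moderation_chain = OpenAIModerationChain()

# Text that passes moderation is returned unchanged; flagged text is replaced
# with a message saying it violated OpenAI's content policy.
print(moderation_chain.run("This is fine."))
```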
BIN docs/docs_skeleton/static/img/ReAct.png (Normal file, 42 KiB)
BIN docs/docs_skeleton/static/img/agents_use_case_1.png (Normal file, 236 KiB)
BIN docs/docs_skeleton/static/img/agents_use_case_trace_1.png (Normal file, 74 KiB)
BIN docs/docs_skeleton/static/img/agents_use_case_trace_2.png (Normal file, 166 KiB)
BIN docs/docs_skeleton/static/img/agents_vs_chains.png (Normal file, 42 KiB)
BIN docs/docs_skeleton/static/img/oai_function_agent.png (Normal file, 177 KiB)
@@ -8,7 +8,7 @@ Here's a few different tools and functionalities to aid in debugging.
## Tracing

Platforms with tracing capabilities like [LangSmith](/docs/guides/langsmith/) and [WandB](/docs/ecosystem/integrations/agent_with_wandb_tracing) are the most comprehensive solutions for debugging. These platforms make it easy to not only log and visualize LLM apps, but also to actively debug, test and refine them.
Platforms with tracing capabilities like [LangSmith](/docs/guides/langsmith/) and [WandB](/docs/integrations/providers/wandb_tracing) are the most comprehensive solutions for debugging. These platforms make it easy to not only log and visualize LLM apps, but also to actively debug, test and refine them.

For anyone building production-grade LLM applications, we highly recommend using a platform like this.
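As a quick illustration of the tracing setup mentioned above, LangSmith tracing can be enabled with environment variables before running a chain; the API key and project name below are placeholders, not values from this repo:

```python
# Sketch: enable LangSmith tracing before running any chains or agents.
# Replace the placeholders with your own LangSmith credentials.
import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = "<your-langsmith-api-key>"
os.environ["LANGCHAIN_PROJECT"] = "my-debugging-project"

# Everything run after this point (chains, agents, retrievers) is traced
# and can be inspected run-by-run in the LangSmith UI.
```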
@@ -13,7 +13,10 @@
|
||||
"\n",
|
||||
"- smaller chunks: split a document into smaller chunks, and embed those (this is ParentDocumentRetriever)\n",
|
||||
"- summary: create a summary for each document, embed that along with (or instead of) the document\n",
|
||||
"- hypothetical questions: create hypothetical questions that each document would be appropriate to answer, embed those along with (or instead of) the document"
|
||||
"- hypothetical questions: create hypothetical questions that each document would be appropriate to answer, embed those along with (or instead of) the document\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Note that this also enables another method of adding embeddings - manually. This is great because you can explicitly add questions or queries that should lead to a document being recovered, giving you more control"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -106,7 +109,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "5d23247d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -122,7 +125,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"id": "92ed5861",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -133,17 +136,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"id": "8afed60c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Document(page_content='Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.', metadata={'doc_id': 'b4ca7817-e3fe-4103-ac81-574fb41439ef', 'source': '../../state_of_the_union.txt'})"
|
||||
"Document(page_content='Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.', metadata={'doc_id': '10e9cbc0-4ba5-4d79-a09b-c033d1ba7b01', 'source': '../../state_of_the_union.txt'})"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -155,7 +158,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 9,
|
||||
"id": "3c9017f1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -165,7 +168,7 @@
|
||||
"9874"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -187,7 +190,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 10,
|
||||
"id": "1433dff4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -201,7 +204,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 11,
|
||||
"id": "35b30390",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -216,17 +219,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 12,
|
||||
"id": "41a2a738",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"summaries = [chain.invoke(d) for d in docs]"
|
||||
"summaries = chain.batch(docs, {\"max_concurrency\": 5})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 13,
|
||||
"id": "7ac5e4b1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -250,7 +253,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 15,
|
||||
"id": "0d93309f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -260,7 +263,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 16,
|
||||
"id": "6d5edf0d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -271,7 +274,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 17,
|
||||
"id": "862ae920",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # We can also add the original chunks to the vectorstore if we so want\n",
|
||||
"# for i, doc in enumerate(docs):\n",
|
||||
"# doc.metadata[id_key] = doc_ids[i]\n",
|
||||
"# retriever.vectorstore.add_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "299232d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -281,17 +297,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 19,
|
||||
"id": "10e404c0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Document(page_content='The document discusses various topics and proposals put forth by the President in a State of the Union address. These include the nomination of a judge for the Supreme Court, securing the border and fixing the immigration system, advancing liberty and justice for women and LGBTQ+ individuals, passing bipartisan legislation, addressing the opioid epidemic and mental health issues, supporting veterans, and ending cancer. The President expresses optimism about the future of the country and emphasizes the strength of the American people.', metadata={'doc_id': '8c7a707d-615d-42d5-919d-bc5178dd1ae4'})"
|
||||
"Document(page_content=\"The document is a transcript of a speech given by the President of the United States. The President discusses several important issues and initiatives, including the nomination of a Supreme Court Justice, border security and immigration reform, protecting women's rights, advancing LGBTQ+ equality, bipartisan legislation, addressing the opioid epidemic and mental health, supporting veterans, investigating the health effects of burn pits on military personnel, ending cancer, and the strength and resilience of the American people.\", metadata={'doc_id': '79fa2e9f-28d9-4372-8af3-2caf4f1de312'})"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -302,7 +318,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 20,
|
||||
"id": "e4cce5c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -312,7 +328,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 21,
|
||||
"id": "c8570dbb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -322,7 +338,7 @@
|
||||
"9194"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -340,6 +356,203 @@
|
||||
"\n",
|
||||
"An LLM can also be used to generate a list of hypothetical questions that could be asked of a particular document. These questions can then be embedded"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"id": "5219b085",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"functions = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"hypothetical_questions\",\n",
|
||||
" \"description\": \"Generate hypothetical questions\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"questions\": {\n",
|
||||
" \"type\": \"array\",\n",
|
||||
" \"items\": {\n",
|
||||
" \"type\": \"string\"\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"required\": [\"questions\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" ]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "523deb92",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
|
||||
"chain = (\n",
|
||||
" {\"doc\": lambda x: x.page_content}\n",
|
||||
" # Only asking for 3 hypothetical questions, but this could be adjusted\n",
|
||||
" | ChatPromptTemplate.from_template(\"Generate a list of 3 hypothetical questions that the below document could be used to answer:\\n\\n{doc}\")\n",
|
||||
" | ChatOpenAI(max_retries=0, model=\"gpt-4\").bind(functions=functions, function_call={\"name\": \"hypothetical_questions\"})\n",
|
||||
" | JsonKeyOutputFunctionsParser(key_name=\"questions\")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "11d30554",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[\"What was the author's initial impression of philosophy as a field of study, and how did it change when they got to college?\",\n",
|
||||
" 'Why did the author decide to switch their focus to Artificial Intelligence (AI)?',\n",
|
||||
" \"What led to the author's disillusionment with the field of AI as it was practiced at the time?\"]"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke(docs[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"id": "3eb2e48c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hypothetical_questions = chain.batch(docs, {\"max_concurrency\": 5})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"id": "b2cd6e75",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The vectorstore to use to index the child chunks\n",
|
||||
"vectorstore = Chroma(\n",
|
||||
" collection_name=\"hypo-questions\",\n",
|
||||
" embedding_function=OpenAIEmbeddings()\n",
|
||||
")\n",
|
||||
"# The storage layer for the parent documents\n",
|
||||
"store = InMemoryStore()\n",
|
||||
"id_key = \"doc_id\"\n",
|
||||
"# The retriever (empty to start)\n",
|
||||
"retriever = MultiVectorRetriever(\n",
|
||||
" vectorstore=vectorstore, \n",
|
||||
" docstore=store, \n",
|
||||
" id_key=id_key,\n",
|
||||
")\n",
|
||||
"doc_ids = [str(uuid.uuid4()) for _ in docs]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"id": "18831b3b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question_docs = []\n",
|
||||
"for i, question_list in enumerate(hypothetical_questions):\n",
|
||||
" question_docs.extend([Document(page_content=s,metadata={id_key: doc_ids[i]}) for s in question_list])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"id": "224b24c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever.vectorstore.add_documents(question_docs)\n",
|
||||
"retriever.docstore.mset(list(zip(doc_ids, docs)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"id": "7b442b90",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sub_docs = vectorstore.similarity_search(\"justice breyer\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"id": "089b5ad0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content=\"What is the President's stance on immigration reform?\", metadata={'doc_id': '505d73e3-8350-46ec-a58e-3af032f04ab3'}),\n",
|
||||
" Document(page_content=\"What is the President's stance on immigration reform?\", metadata={'doc_id': '1c9618f0-7660-4b4f-a37c-509cbbbf6dba'}),\n",
|
||||
" Document(page_content=\"What is the President's stance on immigration reform?\", metadata={'doc_id': '82c08209-b904-46a8-9532-edd2380950b7'}),\n",
|
||||
" Document(page_content='What measures is the President proposing to protect the rights of LGBTQ+ Americans?', metadata={'doc_id': '82c08209-b904-46a8-9532-edd2380950b7'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 71,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sub_docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"id": "7594b24e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retrieved_docs = retriever.get_relevant_documents(\"justice breyer\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 73,
|
||||
"id": "4c120c65",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"9194"
|
||||
]
|
||||
},
|
||||
"execution_count": 73,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(retrieved_docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "616cfeeb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -203,7 +203,7 @@
|
||||
"prompt = ChatPromptTemplate.from_messages([\n",
|
||||
" SystemMessage(content=\"You are a chatbot having a conversation with a human.\"), # The persistent system prompt\n",
|
||||
" MessagesPlaceholder(variable_name=\"chat_history\"), # Where the memory will be stored.\n",
|
||||
" HumanMessagePromptTemplate.from_template(\"{human_input}\"), # Where the human input will injectd\n",
|
||||
" HumanMessagePromptTemplate.from_template(\"{human_input}\"), # Where the human input will injected\n",
|
||||
"])\n",
|
||||
" \n",
|
||||
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)"
|
||||
|
||||
@@ -1,565 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "517a9fd4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# BabyAGI User Guide\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to implement [BabyAGI](https://github.com/yoheinakajima/babyagi/tree/main) by [Yohei Nakajima](https://twitter.com/yoheinakajima). BabyAGI is an AI agent that can generate and pretend to execute tasks based on a given objective.\n",
|
||||
"\n",
|
||||
"This guide will help you understand the components to create your own recursive agents.\n",
|
||||
"\n",
|
||||
"Although BabyAGI uses specific vectorstores/model providers (Pinecone, OpenAI), one of the benefits of implementing it with LangChain is that you can easily swap those out for different options. In this implementation we use a FAISS vectorstore (because it runs locally and is free)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "556af556",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install and Import Required Modules"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 116,
|
||||
"id": "c8a354b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from collections import deque\n",
|
||||
"from typing import Dict, List, Optional, Any\n",
|
||||
"\n",
|
||||
"from langchain import LLMChain, OpenAI, PromptTemplate\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.llms import BaseLLM\n",
|
||||
"from langchain.vectorstores.base import VectorStore\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"from langchain.chains.base import Chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "09f70772",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Connect to the Vector Store\n",
|
||||
"\n",
|
||||
"Depending on what vectorstore you use, this step may look different."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"id": "794045d4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"from langchain.docstore import InMemoryDocstore"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"id": "6e0305eb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define your embedding model\n",
|
||||
"embeddings_model = OpenAIEmbeddings()\n",
|
||||
"# Initialize the vectorstore as empty\n",
|
||||
"import faiss\n",
|
||||
"\n",
|
||||
"embedding_size = 1536\n",
|
||||
"index = faiss.IndexFlatL2(embedding_size)\n",
|
||||
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0f3b72bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define the Chains\n",
|
||||
"\n",
|
||||
"BabyAGI relies on three LLM chains:\n",
|
||||
"- Task creation chain to select new tasks to add to the list\n",
|
||||
"- Task prioritization chain to re-prioritize tasks\n",
|
||||
"- Execution Chain to execute the tasks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 73,
|
||||
"id": "bf4bd5cd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TaskCreationChain(LLMChain):\n",
|
||||
" \"\"\"Chain to generates tasks.\"\"\"\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
|
||||
" \"\"\"Get the response parser.\"\"\"\n",
|
||||
" task_creation_template = (\n",
|
||||
" \"You are a task creation AI that uses the result of an execution agent\"\n",
|
||||
" \" to create new tasks with the following objective: {objective},\"\n",
|
||||
" \" The last completed task has the result: {result}.\"\n",
|
||||
" \" This result was based on this task description: {task_description}.\"\n",
|
||||
" \" These are incomplete tasks: {incomplete_tasks}.\"\n",
|
||||
" \" Based on the result, create new tasks to be completed\"\n",
|
||||
" \" by the AI system that do not overlap with incomplete tasks.\"\n",
|
||||
" \" Return the tasks as an array.\"\n",
|
||||
" )\n",
|
||||
" prompt = PromptTemplate(\n",
|
||||
" template=task_creation_template,\n",
|
||||
" input_variables=[\n",
|
||||
" \"result\",\n",
|
||||
" \"task_description\",\n",
|
||||
" \"incomplete_tasks\",\n",
|
||||
" \"objective\",\n",
|
||||
" ],\n",
|
||||
" )\n",
|
||||
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"id": "b6488ffe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TaskPrioritizationChain(LLMChain):\n",
|
||||
" \"\"\"Chain to prioritize tasks.\"\"\"\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
|
||||
" \"\"\"Get the response parser.\"\"\"\n",
|
||||
" task_prioritization_template = (\n",
|
||||
" \"You are a task prioritization AI tasked with cleaning the formatting of and reprioritizing\"\n",
|
||||
" \" the following tasks: {task_names}.\"\n",
|
||||
" \" Consider the ultimate objective of your team: {objective}.\"\n",
|
||||
" \" Do not remove any tasks. Return the result as a numbered list, like:\"\n",
|
||||
" \" #. First task\"\n",
|
||||
" \" #. Second task\"\n",
|
||||
" \" Start the task list with number {next_task_id}.\"\n",
|
||||
" )\n",
|
||||
" prompt = PromptTemplate(\n",
|
||||
" template=task_prioritization_template,\n",
|
||||
" input_variables=[\"task_names\", \"next_task_id\", \"objective\"],\n",
|
||||
" )\n",
|
||||
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 84,
|
||||
"id": "b43cd580",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ExecutionChain(LLMChain):\n",
|
||||
" \"\"\"Chain to execute tasks.\"\"\"\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
|
||||
" \"\"\"Get the response parser.\"\"\"\n",
|
||||
" execution_template = (\n",
|
||||
" \"You are an AI who performs one task based on the following objective: {objective}.\"\n",
|
||||
" \" Take into account these previously completed tasks: {context}.\"\n",
|
||||
" \" Your task: {task}.\"\n",
|
||||
" \" Response:\"\n",
|
||||
" )\n",
|
||||
" prompt = PromptTemplate(\n",
|
||||
" template=execution_template,\n",
|
||||
" input_variables=[\"objective\", \"context\", \"task\"],\n",
|
||||
" )\n",
|
||||
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3ad996c5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define the BabyAGI Controller\n",
|
||||
"\n",
|
||||
"BabyAGI composes the chains defined above in a (potentially-)infinite loop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 85,
|
||||
"id": "0ada0636",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_next_task(\n",
|
||||
" task_creation_chain: LLMChain,\n",
|
||||
" result: Dict,\n",
|
||||
" task_description: str,\n",
|
||||
" task_list: List[str],\n",
|
||||
" objective: str,\n",
|
||||
") -> List[Dict]:\n",
|
||||
" \"\"\"Get the next task.\"\"\"\n",
|
||||
" incomplete_tasks = \", \".join(task_list)\n",
|
||||
" response = task_creation_chain.run(\n",
|
||||
" result=result,\n",
|
||||
" task_description=task_description,\n",
|
||||
" incomplete_tasks=incomplete_tasks,\n",
|
||||
" objective=objective,\n",
|
||||
" )\n",
|
||||
" new_tasks = response.split(\"\\n\")\n",
|
||||
" return [{\"task_name\": task_name} for task_name in new_tasks if task_name.strip()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 86,
|
||||
"id": "d35250ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prioritize_tasks(\n",
|
||||
" task_prioritization_chain: LLMChain,\n",
|
||||
" this_task_id: int,\n",
|
||||
" task_list: List[Dict],\n",
|
||||
" objective: str,\n",
|
||||
") -> List[Dict]:\n",
|
||||
" \"\"\"Prioritize tasks.\"\"\"\n",
|
||||
" task_names = [t[\"task_name\"] for t in task_list]\n",
|
||||
" next_task_id = int(this_task_id) + 1\n",
|
||||
" response = task_prioritization_chain.run(\n",
|
||||
" task_names=task_names, next_task_id=next_task_id, objective=objective\n",
|
||||
" )\n",
|
||||
" new_tasks = response.split(\"\\n\")\n",
|
||||
" prioritized_task_list = []\n",
|
||||
" for task_string in new_tasks:\n",
|
||||
" if not task_string.strip():\n",
|
||||
" continue\n",
|
||||
" task_parts = task_string.strip().split(\".\", 1)\n",
|
||||
" if len(task_parts) == 2:\n",
|
||||
" task_id = task_parts[0].strip()\n",
|
||||
" task_name = task_parts[1].strip()\n",
|
||||
" prioritized_task_list.append({\"task_id\": task_id, \"task_name\": task_name})\n",
|
||||
" return prioritized_task_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 87,
|
||||
"id": "e3f1840c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _get_top_tasks(vectorstore, query: str, k: int) -> List[str]:\n",
|
||||
" \"\"\"Get the top k tasks based on the query.\"\"\"\n",
|
||||
" results = vectorstore.similarity_search_with_score(query, k=k)\n",
|
||||
" if not results:\n",
|
||||
" return []\n",
|
||||
" sorted_results, _ = zip(*sorted(results, key=lambda x: x[1], reverse=True))\n",
|
||||
" return [str(item.metadata[\"task\"]) for item in sorted_results]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def execute_task(\n",
|
||||
" vectorstore, execution_chain: LLMChain, objective: str, task: str, k: int = 5\n",
|
||||
") -> str:\n",
|
||||
" \"\"\"Execute a task.\"\"\"\n",
|
||||
" context = _get_top_tasks(vectorstore, query=objective, k=k)\n",
|
||||
" return execution_chain.run(objective=objective, context=context, task=task)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 137,
|
||||
"id": "1e978938",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class BabyAGI(Chain, BaseModel):\n",
|
||||
" \"\"\"Controller model for the BabyAGI agent.\"\"\"\n",
|
||||
"\n",
|
||||
" task_list: deque = Field(default_factory=deque)\n",
|
||||
" task_creation_chain: TaskCreationChain = Field(...)\n",
|
||||
" task_prioritization_chain: TaskPrioritizationChain = Field(...)\n",
|
||||
" execution_chain: ExecutionChain = Field(...)\n",
|
||||
" task_id_counter: int = Field(1)\n",
|
||||
" vectorstore: VectorStore = Field(init=False)\n",
|
||||
" max_iterations: Optional[int] = None\n",
|
||||
"\n",
|
||||
" class Config:\n",
|
||||
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
|
||||
"\n",
|
||||
" arbitrary_types_allowed = True\n",
|
||||
"\n",
|
||||
" def add_task(self, task: Dict):\n",
|
||||
" self.task_list.append(task)\n",
|
||||
"\n",
|
||||
" def print_task_list(self):\n",
|
||||
" print(\"\\033[95m\\033[1m\" + \"\\n*****TASK LIST*****\\n\" + \"\\033[0m\\033[0m\")\n",
|
||||
" for t in self.task_list:\n",
|
||||
" print(str(t[\"task_id\"]) + \": \" + t[\"task_name\"])\n",
|
||||
"\n",
|
||||
" def print_next_task(self, task: Dict):\n",
|
||||
" print(\"\\033[92m\\033[1m\" + \"\\n*****NEXT TASK*****\\n\" + \"\\033[0m\\033[0m\")\n",
|
||||
" print(str(task[\"task_id\"]) + \": \" + task[\"task_name\"])\n",
|
||||
"\n",
|
||||
" def print_task_result(self, result: str):\n",
|
||||
" print(\"\\033[93m\\033[1m\" + \"\\n*****TASK RESULT*****\\n\" + \"\\033[0m\\033[0m\")\n",
|
||||
" print(result)\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def input_keys(self) -> List[str]:\n",
|
||||
" return [\"objective\"]\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def output_keys(self) -> List[str]:\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
|
||||
" \"\"\"Run the agent.\"\"\"\n",
|
||||
" objective = inputs[\"objective\"]\n",
|
||||
" first_task = inputs.get(\"first_task\", \"Make a todo list\")\n",
|
||||
" self.add_task({\"task_id\": 1, \"task_name\": first_task})\n",
|
||||
" num_iters = 0\n",
|
||||
" while True:\n",
|
||||
" if self.task_list:\n",
|
||||
" self.print_task_list()\n",
|
||||
"\n",
|
||||
" # Step 1: Pull the first task\n",
|
||||
" task = self.task_list.popleft()\n",
|
||||
" self.print_next_task(task)\n",
|
||||
"\n",
|
||||
" # Step 2: Execute the task\n",
|
||||
" result = execute_task(\n",
|
||||
" self.vectorstore, self.execution_chain, objective, task[\"task_name\"]\n",
|
||||
" )\n",
|
||||
" this_task_id = int(task[\"task_id\"])\n",
|
||||
" self.print_task_result(result)\n",
|
||||
"\n",
|
||||
" # Step 3: Store the result in Pinecone\n",
|
||||
" result_id = f\"result_{task['task_id']}_{num_iters}\"\n",
|
||||
" self.vectorstore.add_texts(\n",
|
||||
" texts=[result],\n",
|
||||
" metadatas=[{\"task\": task[\"task_name\"]}],\n",
|
||||
" ids=[result_id],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Step 4: Create new tasks and reprioritize task list\n",
|
||||
" new_tasks = get_next_task(\n",
|
||||
" self.task_creation_chain,\n",
|
||||
" result,\n",
|
||||
" task[\"task_name\"],\n",
|
||||
" [t[\"task_name\"] for t in self.task_list],\n",
|
||||
" objective,\n",
|
||||
" )\n",
|
||||
" for new_task in new_tasks:\n",
|
||||
" self.task_id_counter += 1\n",
|
||||
" new_task.update({\"task_id\": self.task_id_counter})\n",
|
||||
" self.add_task(new_task)\n",
|
||||
" self.task_list = deque(\n",
|
||||
" prioritize_tasks(\n",
|
||||
" self.task_prioritization_chain,\n",
|
||||
" this_task_id,\n",
|
||||
" list(self.task_list),\n",
|
||||
" objective,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" num_iters += 1\n",
|
||||
" if self.max_iterations is not None and num_iters == self.max_iterations:\n",
|
||||
" print(\n",
|
||||
" \"\\033[91m\\033[1m\" + \"\\n*****TASK ENDING*****\\n\" + \"\\033[0m\\033[0m\"\n",
|
||||
" )\n",
|
||||
" break\n",
|
||||
" return {}\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(\n",
|
||||
" cls, llm: BaseLLM, vectorstore: VectorStore, verbose: bool = False, **kwargs\n",
|
||||
" ) -> \"BabyAGI\":\n",
|
||||
" \"\"\"Initialize the BabyAGI Controller.\"\"\"\n",
|
||||
" task_creation_chain = TaskCreationChain.from_llm(llm, verbose=verbose)\n",
|
||||
" task_prioritization_chain = TaskPrioritizationChain.from_llm(\n",
|
||||
" llm, verbose=verbose\n",
|
||||
" )\n",
|
||||
" execution_chain = ExecutionChain.from_llm(llm, verbose=verbose)\n",
|
||||
" return cls(\n",
|
||||
" task_creation_chain=task_creation_chain,\n",
|
||||
" task_prioritization_chain=task_prioritization_chain,\n",
|
||||
" execution_chain=execution_chain,\n",
|
||||
" vectorstore=vectorstore,\n",
|
||||
" **kwargs,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "05ba762e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Run the BabyAGI\n",
|
||||
"\n",
|
||||
"Now it's time to create the BabyAGI controller and watch it try to accomplish your objective."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 138,
|
||||
"id": "3d220b69",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"OBJECTIVE = \"Write a weather report for SF today\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 139,
|
||||
"id": "8a8e5543",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 140,
|
||||
"id": "3d69899b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Logging of LLMChains\n",
|
||||
"verbose = False\n",
|
||||
"# If None, will keep on going forever\n",
|
||||
"max_iterations: Optional[int] = 3\n",
|
||||
"baby_agi = BabyAGI.from_llm(\n",
|
||||
" llm=llm, vectorstore=vectorstore, verbose=verbose, max_iterations=max_iterations\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 141,
|
||||
"id": "f7957b51",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[95m\u001b[1m\n",
|
||||
"*****TASK LIST*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"1: Make a todo list\n",
|
||||
"\u001b[92m\u001b[1m\n",
|
||||
"*****NEXT TASK*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"1: Make a todo list\n",
|
||||
"\u001b[93m\u001b[1m\n",
|
||||
"*****TASK RESULT*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"1. Check the temperature range for the day.\n",
|
||||
"2. Gather temperature data for SF today.\n",
|
||||
"3. Analyze the temperature data and create a weather report.\n",
|
||||
"4. Publish the weather report.\n",
|
||||
"\u001b[95m\u001b[1m\n",
|
||||
"*****TASK LIST*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"2: Gather data on the expected temperature range for the day.\n",
|
||||
"3: Collect data on the expected precipitation for the day.\n",
|
||||
"4: Analyze the data and create a weather report.\n",
|
||||
"5: Check the current weather conditions in SF.\n",
|
||||
"6: Publish the weather report.\n",
|
||||
"\u001b[92m\u001b[1m\n",
|
||||
"*****NEXT TASK*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"2: Gather data on the expected temperature range for the day.\n",
|
||||
"\u001b[93m\u001b[1m\n",
|
||||
"*****TASK RESULT*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"I have gathered data on the expected temperature range for the day in San Francisco. The forecast is for temperatures to range from a low of 55 degrees Fahrenheit to a high of 68 degrees Fahrenheit.\n",
|
||||
"\u001b[95m\u001b[1m\n",
|
||||
"*****TASK LIST*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"3: Check the current weather conditions in SF.\n",
|
||||
"4: Calculate the average temperature for the day in San Francisco.\n",
|
||||
"5: Determine the probability of precipitation for the day in San Francisco.\n",
|
||||
"6: Identify any potential weather warnings or advisories for the day in San Francisco.\n",
|
||||
"7: Research any historical weather patterns for the day in San Francisco.\n",
|
||||
"8: Compare the expected temperature range to the historical average for the day in San Francisco.\n",
|
||||
"9: Collect data on the expected precipitation for the day.\n",
|
||||
"10: Analyze the data and create a weather report.\n",
|
||||
"11: Publish the weather report.\n",
|
||||
"\u001b[92m\u001b[1m\n",
|
||||
"*****NEXT TASK*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"3: Check the current weather conditions in SF.\n",
|
||||
"\u001b[93m\u001b[1m\n",
|
||||
"*****TASK RESULT*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"I am checking the current weather conditions in SF. According to the data I have gathered, the temperature in SF today is currently around 65 degrees Fahrenheit with clear skies. The temperature range for the day is expected to be between 60 and 70 degrees Fahrenheit.\n",
|
||||
"\u001b[91m\u001b[1m\n",
|
||||
"*****TASK ENDING*****\n",
|
||||
"\u001b[0m\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'objective': 'Write a weather report for SF today'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 141,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"baby_agi({\"objective\": OBJECTIVE})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "898a210b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,647 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "517a9fd4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# BabyAGI with Tools\n",
|
||||
"\n",
|
||||
"This notebook builds on top of [baby agi](baby_agi.html), but shows how you can swap out the execution chain. The previous execution chain was just an LLM which made stuff up. By swapping it out with an agent that has access to tools, we can hopefully get real reliable information"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "556af556",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install and Import Required Modules"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c8a354b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from collections import deque\n",
|
||||
"from typing import Dict, List, Optional, Any\n",
|
||||
"\n",
|
||||
"from langchain import LLMChain, OpenAI, PromptTemplate\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.llms import BaseLLM\n",
|
||||
"from langchain.vectorstores.base import VectorStore\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"from langchain.chains.base import Chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "09f70772",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Connect to the Vector Store\n",
|
||||
"\n",
|
||||
"Depending on what vectorstore you use, this step may look different."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "794045d4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install faiss-cpu > /dev/null\n",
|
||||
"%pip install google-search-results > /dev/null\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"from langchain.docstore import InMemoryDocstore"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6e0305eb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define your embedding model\n",
|
||||
"embeddings_model = OpenAIEmbeddings()\n",
|
||||
"# Initialize the vectorstore as empty\n",
|
||||
"import faiss\n",
|
||||
"\n",
|
||||
"embedding_size = 1536\n",
|
||||
"index = faiss.IndexFlatL2(embedding_size)\n",
|
||||
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0f3b72bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define the Chains\n",
|
||||
"\n",
|
||||
"BabyAGI relies on three LLM chains:\n",
|
||||
"- Task creation chain to select new tasks to add to the list\n",
|
||||
"- Task prioritization chain to re-prioritize tasks\n",
|
||||
"- Execution Chain to execute the tasks\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"NOTE: in this notebook, the Execution chain will now be an agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "bf4bd5cd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TaskCreationChain(LLMChain):\n",
|
||||
" \"\"\"Chain to generates tasks.\"\"\"\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
|
||||
" \"\"\"Get the response parser.\"\"\"\n",
|
||||
" task_creation_template = (\n",
|
||||
" \"You are an task creation AI that uses the result of an execution agent\"\n",
|
||||
" \" to create new tasks with the following objective: {objective},\"\n",
|
||||
" \" The last completed task has the result: {result}.\"\n",
|
||||
" \" This result was based on this task description: {task_description}.\"\n",
|
||||
" \" These are incomplete tasks: {incomplete_tasks}.\"\n",
|
||||
" \" Based on the result, create new tasks to be completed\"\n",
|
||||
" \" by the AI system that do not overlap with incomplete tasks.\"\n",
|
||||
" \" Return the tasks as an array.\"\n",
|
||||
" )\n",
|
||||
" prompt = PromptTemplate(\n",
|
||||
" template=task_creation_template,\n",
|
||||
" input_variables=[\n",
|
||||
" \"result\",\n",
|
||||
" \"task_description\",\n",
|
||||
" \"incomplete_tasks\",\n",
|
||||
" \"objective\",\n",
|
||||
" ],\n",
|
||||
" )\n",
|
||||
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "b6488ffe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TaskPrioritizationChain(LLMChain):\n",
|
||||
" \"\"\"Chain to prioritize tasks.\"\"\"\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
|
||||
" \"\"\"Get the response parser.\"\"\"\n",
|
||||
" task_prioritization_template = (\n",
|
||||
" \"You are an task prioritization AI tasked with cleaning the formatting of and reprioritizing\"\n",
|
||||
" \" the following tasks: {task_names}.\"\n",
|
||||
" \" Consider the ultimate objective of your team: {objective}.\"\n",
|
||||
" \" Do not remove any tasks. Return the result as a numbered list, like:\"\n",
|
||||
" \" #. First task\"\n",
|
||||
" \" #. Second task\"\n",
|
||||
" \" Start the task list with number {next_task_id}.\"\n",
|
||||
" )\n",
|
||||
" prompt = PromptTemplate(\n",
|
||||
" template=task_prioritization_template,\n",
|
||||
" input_variables=[\"task_names\", \"next_task_id\", \"objective\"],\n",
|
||||
" )\n",
|
||||
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"id": "b43cd580",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
|
||||
"from langchain import OpenAI, SerpAPIWrapper, LLMChain\n",
|
||||
"\n",
|
||||
"todo_prompt = PromptTemplate.from_template(\n",
|
||||
" \"You are a planner who is an expert at coming up with a todo list for a given objective. Come up with a todo list for this objective: {objective}\"\n",
|
||||
")\n",
|
||||
"todo_chain = LLMChain(llm=OpenAI(temperature=0), prompt=todo_prompt)\n",
|
||||
"search = SerpAPIWrapper()\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events\",\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"TODO\",\n",
|
||||
" func=todo_chain.run,\n",
|
||||
" description=\"useful for when you need to come up with todo lists. Input: an objective to create a todo list for. Output: a todo list for that objective. Please be very clear what the objective is!\",\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"prefix = \"\"\"You are an AI who performs one task based on the following objective: {objective}. Take into account these previously completed tasks: {context}.\"\"\"\n",
|
||||
"suffix = \"\"\"Question: {task}\n",
|
||||
"{agent_scratchpad}\"\"\"\n",
|
||||
"prompt = ZeroShotAgent.create_prompt(\n",
|
||||
" tools,\n",
|
||||
" prefix=prefix,\n",
|
||||
" suffix=suffix,\n",
|
||||
" input_variables=[\"objective\", \"task\", \"context\", \"agent_scratchpad\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3ad996c5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Define the BabyAGI Controller\n",
|
||||
"\n",
|
||||
"BabyAGI composes the chains defined above in a (potentially-)infinite loop."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"id": "0ada0636",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_next_task(\n",
|
||||
" task_creation_chain: LLMChain,\n",
|
||||
" result: Dict,\n",
|
||||
" task_description: str,\n",
|
||||
" task_list: List[str],\n",
|
||||
" objective: str,\n",
|
||||
") -> List[Dict]:\n",
|
||||
" \"\"\"Get the next task.\"\"\"\n",
|
||||
" incomplete_tasks = \", \".join(task_list)\n",
|
||||
" response = task_creation_chain.run(\n",
|
||||
" result=result,\n",
|
||||
" task_description=task_description,\n",
|
||||
" incomplete_tasks=incomplete_tasks,\n",
|
||||
" objective=objective,\n",
|
||||
" )\n",
|
||||
" new_tasks = response.split(\"\\n\")\n",
|
||||
" return [{\"task_name\": task_name} for task_name in new_tasks if task_name.strip()]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"id": "d35250ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prioritize_tasks(\n",
|
||||
" task_prioritization_chain: LLMChain,\n",
|
||||
" this_task_id: int,\n",
|
||||
" task_list: List[Dict],\n",
|
||||
" objective: str,\n",
|
||||
") -> List[Dict]:\n",
|
||||
" \"\"\"Prioritize tasks.\"\"\"\n",
|
||||
" task_names = [t[\"task_name\"] for t in task_list]\n",
|
||||
" next_task_id = int(this_task_id) + 1\n",
|
||||
" response = task_prioritization_chain.run(\n",
|
||||
" task_names=task_names, next_task_id=next_task_id, objective=objective\n",
|
||||
" )\n",
|
||||
" new_tasks = response.split(\"\\n\")\n",
|
||||
" prioritized_task_list = []\n",
|
||||
" for task_string in new_tasks:\n",
|
||||
" if not task_string.strip():\n",
|
||||
" continue\n",
|
||||
" task_parts = task_string.strip().split(\".\", 1)\n",
|
||||
" if len(task_parts) == 2:\n",
|
||||
" task_id = task_parts[0].strip()\n",
|
||||
" task_name = task_parts[1].strip()\n",
|
||||
" prioritized_task_list.append({\"task_id\": task_id, \"task_name\": task_name})\n",
|
||||
" return prioritized_task_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"id": "e3f1840c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _get_top_tasks(vectorstore, query: str, k: int) -> List[str]:\n",
|
||||
" \"\"\"Get the top k tasks based on the query.\"\"\"\n",
|
||||
" results = vectorstore.similarity_search_with_score(query, k=k)\n",
|
||||
" if not results:\n",
|
||||
" return []\n",
|
||||
" sorted_results, _ = zip(*sorted(results, key=lambda x: x[1], reverse=True))\n",
|
||||
" return [str(item.metadata[\"task\"]) for item in sorted_results]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def execute_task(\n",
|
||||
" vectorstore, execution_chain: LLMChain, objective: str, task: str, k: int = 5\n",
|
||||
") -> str:\n",
|
||||
" \"\"\"Execute a task.\"\"\"\n",
|
||||
" context = _get_top_tasks(vectorstore, query=objective, k=k)\n",
|
||||
" return execution_chain.run(objective=objective, context=context, task=task)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"id": "1e978938",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class BabyAGI(Chain, BaseModel):\n",
|
||||
" \"\"\"Controller model for the BabyAGI agent.\"\"\"\n",
|
||||
"\n",
|
||||
" task_list: deque = Field(default_factory=deque)\n",
|
||||
" task_creation_chain: TaskCreationChain = Field(...)\n",
|
||||
" task_prioritization_chain: TaskPrioritizationChain = Field(...)\n",
|
||||
" execution_chain: AgentExecutor = Field(...)\n",
|
||||
" task_id_counter: int = Field(1)\n",
|
||||
" vectorstore: VectorStore = Field(init=False)\n",
|
||||
" max_iterations: Optional[int] = None\n",
|
||||
"\n",
|
||||
" class Config:\n",
|
||||
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
|
||||
"\n",
|
||||
" arbitrary_types_allowed = True\n",
|
||||
"\n",
|
||||
" def add_task(self, task: Dict):\n",
|
||||
" self.task_list.append(task)\n",
|
||||
"\n",
|
||||
" def print_task_list(self):\n",
|
||||
" print(\"\\033[95m\\033[1m\" + \"\\n*****TASK LIST*****\\n\" + \"\\033[0m\\033[0m\")\n",
|
||||
" for t in self.task_list:\n",
|
||||
" print(str(t[\"task_id\"]) + \": \" + t[\"task_name\"])\n",
|
||||
"\n",
|
||||
" def print_next_task(self, task: Dict):\n",
|
||||
" print(\"\\033[92m\\033[1m\" + \"\\n*****NEXT TASK*****\\n\" + \"\\033[0m\\033[0m\")\n",
|
||||
" print(str(task[\"task_id\"]) + \": \" + task[\"task_name\"])\n",
|
||||
"\n",
|
||||
" def print_task_result(self, result: str):\n",
|
||||
" print(\"\\033[93m\\033[1m\" + \"\\n*****TASK RESULT*****\\n\" + \"\\033[0m\\033[0m\")\n",
|
||||
" print(result)\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def input_keys(self) -> List[str]:\n",
|
||||
" return [\"objective\"]\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def output_keys(self) -> List[str]:\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
|
||||
" \"\"\"Run the agent.\"\"\"\n",
|
||||
" objective = inputs[\"objective\"]\n",
|
||||
" first_task = inputs.get(\"first_task\", \"Make a todo list\")\n",
|
||||
" self.add_task({\"task_id\": 1, \"task_name\": first_task})\n",
|
||||
" num_iters = 0\n",
|
||||
" while True:\n",
|
||||
" if self.task_list:\n",
|
||||
" self.print_task_list()\n",
|
||||
"\n",
|
||||
" # Step 1: Pull the first task\n",
|
||||
" task = self.task_list.popleft()\n",
|
||||
" self.print_next_task(task)\n",
|
||||
"\n",
|
||||
" # Step 2: Execute the task\n",
|
||||
" result = execute_task(\n",
|
||||
" self.vectorstore, self.execution_chain, objective, task[\"task_name\"]\n",
|
||||
" )\n",
|
||||
" this_task_id = int(task[\"task_id\"])\n",
|
||||
" self.print_task_result(result)\n",
|
||||
"\n",
|
||||
" # Step 3: Store the result in Pinecone\n",
|
||||
" result_id = f\"result_{task['task_id']}_{num_iters}\"\n",
|
||||
" self.vectorstore.add_texts(\n",
|
||||
" texts=[result],\n",
|
||||
" metadatas=[{\"task\": task[\"task_name\"]}],\n",
|
||||
" ids=[result_id],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Step 4: Create new tasks and reprioritize task list\n",
|
||||
" new_tasks = get_next_task(\n",
|
||||
" self.task_creation_chain,\n",
|
||||
" result,\n",
|
||||
" task[\"task_name\"],\n",
|
||||
" [t[\"task_name\"] for t in self.task_list],\n",
|
||||
" objective,\n",
|
||||
" )\n",
|
||||
" for new_task in new_tasks:\n",
|
||||
" self.task_id_counter += 1\n",
|
||||
" new_task.update({\"task_id\": self.task_id_counter})\n",
|
||||
" self.add_task(new_task)\n",
|
||||
" self.task_list = deque(\n",
|
||||
" prioritize_tasks(\n",
|
||||
" self.task_prioritization_chain,\n",
|
||||
" this_task_id,\n",
|
||||
" list(self.task_list),\n",
|
||||
" objective,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" num_iters += 1\n",
|
||||
" if self.max_iterations is not None and num_iters == self.max_iterations:\n",
|
||||
" print(\n",
|
||||
" \"\\033[91m\\033[1m\" + \"\\n*****TASK ENDING*****\\n\" + \"\\033[0m\\033[0m\"\n",
|
||||
" )\n",
|
||||
" break\n",
|
||||
" return {}\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def from_llm(\n",
|
||||
" cls, llm: BaseLLM, vectorstore: VectorStore, verbose: bool = False, **kwargs\n",
|
||||
" ) -> \"BabyAGI\":\n",
|
||||
" \"\"\"Initialize the BabyAGI Controller.\"\"\"\n",
|
||||
" task_creation_chain = TaskCreationChain.from_llm(llm, verbose=verbose)\n",
|
||||
" task_prioritization_chain = TaskPrioritizationChain.from_llm(\n",
|
||||
" llm, verbose=verbose\n",
|
||||
" )\n",
|
||||
" llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
|
||||
" tool_names = [tool.name for tool in tools]\n",
|
||||
" agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)\n",
|
||||
" agent_executor = AgentExecutor.from_agent_and_tools(\n",
|
||||
" agent=agent, tools=tools, verbose=True\n",
|
||||
" )\n",
|
||||
" return cls(\n",
|
||||
" task_creation_chain=task_creation_chain,\n",
|
||||
" task_prioritization_chain=task_prioritization_chain,\n",
|
||||
" execution_chain=agent_executor,\n",
|
||||
" vectorstore=vectorstore,\n",
|
||||
" **kwargs,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "05ba762e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Run the BabyAGI\n",
|
||||
"\n",
|
||||
"Now it's time to create the BabyAGI controller and watch it try to accomplish your objective."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"id": "3d220b69",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"OBJECTIVE = \"Write a weather report for SF today\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"id": "8a8e5543",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"id": "3d69899b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Logging of LLMChains\n",
|
||||
"verbose = False\n",
|
||||
"# If None, will keep on going forever\n",
|
||||
"max_iterations: Optional[int] = 3\n",
|
||||
"baby_agi = BabyAGI.from_llm(\n",
|
||||
" llm=llm, vectorstore=vectorstore, verbose=verbose, max_iterations=max_iterations\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"id": "f7957b51",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[95m\u001b[1m\n",
|
||||
"*****TASK LIST*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"1: Make a todo list\n",
|
||||
"\u001b[92m\u001b[1m\n",
|
||||
"*****NEXT TASK*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"1: Make a todo list\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mThought: I need to gather data on the current weather conditions in SF\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: Current weather conditions in SF\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mHigh 67F. Winds WNW at 10 to 15 mph. Clear to partly cloudy.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to make a todo list\n",
|
||||
"Action: TODO\n",
|
||||
"Action Input: Write a weather report for SF today\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"1. Research current weather conditions in San Francisco\n",
|
||||
"2. Gather data on temperature, humidity, wind speed, and other relevant weather conditions\n",
|
||||
"3. Analyze data to determine current weather trends\n",
|
||||
"4. Write a brief introduction to the weather report\n",
|
||||
"5. Describe current weather conditions in San Francisco\n",
|
||||
"6. Discuss any upcoming weather changes\n",
|
||||
"7. Summarize the weather report\n",
|
||||
"8. Proofread and edit the report\n",
|
||||
"9. Submit the report\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: A weather report for SF today should include research on current weather conditions in San Francisco, gathering data on temperature, humidity, wind speed, and other relevant weather conditions, analyzing data to determine current weather trends, writing a brief introduction to the weather report, describing current weather conditions in San Francisco, discussing any upcoming weather changes, summarizing the weather report, proofreading and editing the report, and submitting the report.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[93m\u001b[1m\n",
|
||||
"*****TASK RESULT*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"A weather report for SF today should include research on current weather conditions in San Francisco, gathering data on temperature, humidity, wind speed, and other relevant weather conditions, analyzing data to determine current weather trends, writing a brief introduction to the weather report, describing current weather conditions in San Francisco, discussing any upcoming weather changes, summarizing the weather report, proofreading and editing the report, and submitting the report.\n",
|
||||
"\u001b[95m\u001b[1m\n",
|
||||
"*****TASK LIST*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"2: Gather data on temperature, humidity, wind speed, and other relevant weather conditions\n",
|
||||
"3: Analyze data to determine current weather trends\n",
|
||||
"4: Write a brief introduction to the weather report\n",
|
||||
"5: Describe current weather conditions in San Francisco\n",
|
||||
"6: Discuss any upcoming weather changes\n",
|
||||
"7: Summarize the weather report\n",
|
||||
"8: Proofread and edit the report\n",
|
||||
"9: Submit the report\n",
|
||||
"1: Research current weather conditions in San Francisco\n",
|
||||
"\u001b[92m\u001b[1m\n",
|
||||
"*****NEXT TASK*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"2: Gather data on temperature, humidity, wind speed, and other relevant weather conditions\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mThought: I need to search for the current weather conditions in SF\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: Current weather conditions in SF\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mHigh 67F. Winds WNW at 10 to 15 mph. Clear to partly cloudy.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to make a todo list\n",
|
||||
"Action: TODO\n",
|
||||
"Action Input: Create a weather report for SF today\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"1. Gather current weather data for SF, including temperature, wind speed, humidity, and precipitation.\n",
|
||||
"2. Research historical weather data for SF to compare current conditions.\n",
|
||||
"3. Analyze current and historical data to determine any trends or patterns.\n",
|
||||
"4. Create a visual representation of the data, such as a graph or chart.\n",
|
||||
"5. Write a summary of the weather report, including key findings and any relevant information.\n",
|
||||
"6. Publish the weather report on a website or other platform.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: Today in San Francisco, the temperature is 67F with winds WNW at 10 to 15 mph. The sky is clear to partly cloudy.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[93m\u001b[1m\n",
|
||||
"*****TASK RESULT*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"Today in San Francisco, the temperature is 67F with winds WNW at 10 to 15 mph. The sky is clear to partly cloudy.\n",
|
||||
"\u001b[95m\u001b[1m\n",
|
||||
"*****TASK LIST*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"3: Research current weather conditions in San Francisco\n",
|
||||
"4: Compare the current weather conditions in San Francisco to the average for this time of year.\n",
|
||||
"5: Identify any potential weather-related hazards in the area.\n",
|
||||
"6: Research any historical weather patterns in San Francisco.\n",
|
||||
"7: Analyze data to determine current weather trends\n",
|
||||
"8: Include any relevant data from nearby cities in the report.\n",
|
||||
"9: Include any relevant data from the National Weather Service in the report.\n",
|
||||
"10: Include any relevant data from local news sources in the report.\n",
|
||||
"11: Include any relevant data from online weather sources in the report.\n",
|
||||
"12: Include any relevant data from local meteorologists in the report.\n",
|
||||
"13: Include any relevant data from local weather stations in the report.\n",
|
||||
"14: Include any relevant data from satellite images in the report.\n",
|
||||
"15: Describe current weather conditions in San Francisco\n",
|
||||
"16: Discuss any upcoming weather changes\n",
|
||||
"17: Write a brief introduction to the weather report\n",
|
||||
"18: Summarize the weather report\n",
|
||||
"19: Proofread and edit the report\n",
|
||||
"20: Submit the report\n",
|
||||
"\u001b[92m\u001b[1m\n",
|
||||
"*****NEXT TASK*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"3: Research current weather conditions in San Francisco\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mThought: I need to search for current weather conditions in San Francisco\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: Current weather conditions in San Francisco\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mTodaySun 04/09 High 67 · 1% Precip. ; TonightSun 04/09 Low 49 · 9% Precip. ; TomorrowMon 04/10 High 64 · 11% Precip.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: Today in San Francisco, the high temperature is 67 degrees with 1% chance of precipitation. The low temperature tonight is 49 degrees with 9% chance of precipitation. Tomorrow's high temperature is 64 degrees with 11% chance of precipitation.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[93m\u001b[1m\n",
|
||||
"*****TASK RESULT*****\n",
|
||||
"\u001b[0m\u001b[0m\n",
|
||||
"Today in San Francisco, the high temperature is 67 degrees with 1% chance of precipitation. The low temperature tonight is 49 degrees with 9% chance of precipitation. Tomorrow's high temperature is 64 degrees with 11% chance of precipitation.\n",
|
||||
"\u001b[91m\u001b[1m\n",
|
||||
"*****TASK ENDING*****\n",
|
||||
"\u001b[0m\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'objective': 'Write a weather report for SF today'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"baby_agi({\"objective\": OBJECTIVE})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "898a210b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -66,11 +66,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from git import Repo\n",
|
||||
"# from git import Repo\n",
|
||||
"from langchain.text_splitter import Language\n",
|
||||
"from langchain.document_loaders.generic import GenericLoader\n",
|
||||
"from langchain.document_loaders.parsers import LanguageParser"
|
||||
@@ -78,13 +78,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Clone\n",
|
||||
"repo_path = \"/Users/rlm/Desktop/test_repo\"\n",
|
||||
"repo = Repo.clone_from(\"https://github.com/hwchase17/langchain\", to_path=repo_path)"
|
||||
"# repo = Repo.clone_from(\"https://github.com/hwchase17/langchain\", to_path=repo_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -100,7 +100,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -109,7 +109,7 @@
|
||||
"1293"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -139,7 +139,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -148,7 +148,7 @@
|
||||
"3748"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -187,7 +187,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -333,68 +333,673 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Private chat\n",
"### Open source LLMs\n",
"\n",
"We can use [Code LLaMA](https://about.fb.com/news/2023/08/code-llama-ai-for-coding/) via the Ollama integration.\n",
"We can use [Code LLaMA](https://about.fb.com/news/2023/08/code-llama-ai-for-coding/) via LlamaCpp or the [Ollama integration](https://ollama.ai/blog/run-code-llama-locally).\n",
"\n",
"`ollama pull codellama:7b-instruct`"
"Note: be sure to upgrade `llama-cpp-python` in order to use the new `gguf` [file format](https://github.com/abetlen/llama-cpp-python/pull/633).\n",
"\n",
"```\n",
"CMAKE_ARGS=\"-DLLAMA_METAL=on\" FORCE_CMAKE=1 /Users/rlm/miniforge3/envs/llama2/bin/pip install -U llama-cpp-python --no-cache-dir\n",
"```\n",
" \n",
"Check out the latest code-llama models [here](https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GGUF/tree/main)."
]
},
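If you prefer the Ollama route mentioned above, a minimal sketch looks like the following. The model tag and streaming callback mirror the Ollama-based version of the next cell that this revision replaces, so treat it as illustrative rather than part of the updated notebook:

```python
from langchain.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Assumes the model has already been pulled locally: `ollama pull codellama:7b-instruct`
llm = Ollama(
    model="codellama:7b-instruct",
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
```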
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import Ollama\n",
|
||||
"from langchain.llms import LlamaCpp\n",
|
||||
"from langchain import PromptTemplate, LLMChain\n",
|
||||
"from langchain.callbacks.manager import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler \n",
|
||||
"llm = Ollama(model=\"codellama:7b-instruct\", \n",
|
||||
" callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]))\n",
|
||||
"memory = ConversationSummaryMemory(llm=llm,memory_key=\"chat_history\",return_messages=True)\n",
|
||||
"qa_llama=ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)"
|
||||
"from langchain.memory import ConversationSummaryMemory\n",
|
||||
"from langchain.chains import ConversationalRetrievalChain \n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
|
||||
]
|
||||
},
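For reference, here is a sketch of how these imports get wired together once `llm` (the LlamaCpp model built below) and a `retriever` (built earlier in the notebook) exist; the argument names follow the lines this revision removes, so consider it illustrative rather than the committed code:

```python
from langchain.memory import ConversationSummaryMemory
from langchain.chains import ConversationalRetrievalChain

# Summarizing chat memory, keyed the way ConversationalRetrievalChain expects
memory = ConversationSummaryMemory(
    llm=llm, memory_key="chat_history", return_messages=True
)
qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)

# qa("How can I initialize a ReAct agent?")  # returns a dict that includes the answer
```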
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"llama_model_loader: loaded meta data with 17 key-value pairs and 363 tensors from /Users/rlm/Desktop/Code/llama/code-llama/codellama-13b-instruct.Q4_K_M.gguf (version GGUF V1 (latest))\n",
|
||||
"llama_model_loader: - tensor 0: token_embd.weight q4_0 [ 5120, 32016, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 1: output_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 2: output.weight f16 [ 5120, 32016, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 3: blk.0.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 4: blk.0.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 5: blk.0.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 6: blk.0.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 7: blk.0.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 8: blk.0.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 9: blk.0.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 10: blk.0.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 11: blk.0.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 12: blk.1.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 13: blk.1.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 14: blk.1.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 15: blk.1.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 16: blk.1.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 17: blk.1.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 18: blk.1.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 19: blk.1.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 20: blk.1.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 21: blk.2.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 22: blk.2.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 23: blk.2.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 24: blk.2.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 25: blk.2.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 26: blk.2.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 27: blk.2.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 28: blk.2.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 29: blk.2.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 30: blk.3.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 31: blk.3.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 32: blk.3.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 33: blk.3.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 34: blk.3.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 35: blk.3.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 36: blk.3.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 37: blk.3.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 38: blk.3.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 39: blk.4.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 40: blk.4.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 41: blk.4.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 42: blk.4.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 43: blk.4.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 44: blk.4.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 45: blk.4.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 46: blk.4.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 47: blk.4.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 48: blk.5.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 49: blk.5.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 50: blk.5.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 51: blk.5.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 52: blk.5.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 53: blk.5.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 54: blk.5.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 55: blk.5.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 56: blk.5.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 57: blk.6.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 58: blk.6.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 59: blk.6.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 60: blk.6.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 61: blk.6.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 62: blk.6.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 63: blk.6.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 64: blk.6.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 65: blk.6.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 66: blk.7.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 67: blk.7.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 68: blk.7.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 69: blk.7.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 70: blk.7.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 71: blk.7.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 72: blk.7.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 73: blk.7.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 74: blk.7.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 75: blk.8.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 76: blk.8.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 77: blk.8.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 78: blk.8.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 79: blk.8.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 80: blk.8.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 81: blk.8.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 82: blk.8.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 83: blk.8.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 84: blk.9.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 85: blk.9.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 86: blk.9.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 87: blk.9.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 88: blk.9.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 89: blk.9.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 90: blk.9.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 91: blk.9.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 92: blk.9.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 93: blk.10.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 94: blk.10.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 95: blk.10.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 96: blk.10.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 97: blk.10.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 98: blk.10.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 99: blk.10.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 100: blk.10.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 101: blk.10.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 102: blk.11.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 103: blk.11.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 104: blk.11.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 105: blk.11.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 106: blk.11.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 107: blk.11.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 108: blk.11.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 109: blk.11.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 110: blk.11.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 111: blk.12.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 112: blk.12.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 113: blk.12.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 114: blk.12.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 115: blk.12.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 116: blk.12.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 117: blk.12.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 118: blk.12.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 119: blk.12.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 120: blk.13.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 121: blk.13.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 122: blk.13.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 123: blk.13.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 124: blk.13.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 125: blk.13.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 126: blk.13.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 127: blk.13.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 128: blk.13.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 129: blk.14.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 130: blk.14.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 131: blk.14.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 132: blk.14.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 133: blk.14.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 134: blk.14.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 135: blk.14.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 136: blk.14.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 137: blk.14.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 138: blk.15.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 139: blk.15.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 140: blk.15.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 141: blk.15.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 142: blk.15.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 143: blk.15.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 144: blk.15.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 145: blk.15.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 146: blk.15.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 147: blk.16.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 148: blk.16.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 149: blk.16.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 150: blk.16.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 151: blk.16.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 152: blk.16.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 153: blk.16.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 154: blk.16.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 155: blk.16.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 156: blk.17.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 157: blk.17.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 158: blk.17.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 159: blk.17.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 160: blk.17.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 161: blk.17.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 162: blk.17.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 163: blk.17.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 164: blk.17.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 165: blk.18.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 166: blk.18.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 167: blk.18.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 168: blk.18.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 169: blk.18.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 170: blk.18.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 171: blk.18.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 172: blk.18.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 173: blk.18.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 174: blk.19.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 175: blk.19.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 176: blk.19.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 177: blk.19.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 178: blk.19.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 179: blk.19.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 180: blk.19.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 181: blk.19.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 182: blk.19.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 183: blk.20.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 184: blk.20.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 185: blk.20.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 186: blk.20.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 187: blk.20.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 188: blk.20.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 189: blk.20.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 190: blk.20.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 191: blk.20.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 192: blk.21.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 193: blk.21.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 194: blk.21.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 195: blk.21.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 196: blk.21.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 197: blk.21.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 198: blk.21.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 199: blk.21.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 200: blk.21.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 201: blk.22.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 202: blk.22.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 203: blk.22.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 204: blk.22.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 205: blk.22.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 206: blk.22.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 207: blk.22.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 208: blk.22.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 209: blk.22.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 210: blk.23.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 211: blk.23.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 212: blk.23.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 213: blk.23.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 214: blk.23.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 215: blk.23.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 216: blk.23.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 217: blk.23.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 218: blk.23.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 219: blk.24.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 220: blk.24.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 221: blk.24.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 222: blk.24.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 223: blk.24.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 224: blk.24.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 225: blk.24.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 226: blk.24.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 227: blk.24.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 228: blk.25.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 229: blk.25.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 230: blk.25.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 231: blk.25.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 232: blk.25.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 233: blk.25.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 234: blk.25.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 235: blk.25.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 236: blk.25.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 237: blk.26.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 238: blk.26.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 239: blk.26.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 240: blk.26.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 241: blk.26.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 242: blk.26.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 243: blk.26.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 244: blk.26.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 245: blk.26.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 246: blk.27.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 247: blk.27.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 248: blk.27.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 249: blk.27.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 250: blk.27.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 251: blk.27.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 252: blk.27.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 253: blk.27.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 254: blk.27.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 255: blk.28.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 256: blk.28.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 257: blk.28.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 258: blk.28.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 259: blk.28.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 260: blk.28.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 261: blk.28.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 262: blk.28.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 263: blk.28.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 264: blk.29.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 265: blk.29.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 266: blk.29.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 267: blk.29.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 268: blk.29.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 269: blk.29.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 270: blk.29.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 271: blk.29.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 272: blk.29.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 273: blk.30.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 274: blk.30.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 275: blk.30.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 276: blk.30.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 277: blk.30.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 278: blk.30.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 279: blk.30.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 280: blk.30.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 281: blk.30.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 282: blk.31.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 283: blk.31.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 284: blk.31.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 285: blk.31.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 286: blk.31.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 287: blk.31.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 288: blk.31.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 289: blk.31.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 290: blk.31.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 291: blk.32.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 292: blk.32.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 293: blk.32.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 294: blk.32.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 295: blk.32.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 296: blk.32.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 297: blk.32.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 298: blk.32.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 299: blk.32.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 300: blk.33.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 301: blk.33.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 302: blk.33.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 303: blk.33.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 304: blk.33.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 305: blk.33.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 306: blk.33.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 307: blk.33.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 308: blk.33.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 309: blk.34.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 310: blk.34.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 311: blk.34.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 312: blk.34.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 313: blk.34.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 314: blk.34.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 315: blk.34.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 316: blk.34.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 317: blk.34.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 318: blk.35.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 319: blk.35.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 320: blk.35.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 321: blk.35.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 322: blk.35.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 323: blk.35.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 324: blk.35.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 325: blk.35.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 326: blk.35.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 327: blk.36.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 328: blk.36.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 329: blk.36.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 330: blk.36.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 331: blk.36.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 332: blk.36.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 333: blk.36.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 334: blk.36.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 335: blk.36.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 336: blk.37.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 337: blk.37.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 338: blk.37.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 339: blk.37.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 340: blk.37.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 341: blk.37.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 342: blk.37.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 343: blk.37.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 344: blk.37.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 345: blk.38.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 346: blk.38.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 347: blk.38.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 348: blk.38.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 349: blk.38.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 350: blk.38.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 351: blk.38.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 352: blk.38.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 353: blk.38.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 354: blk.39.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 355: blk.39.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 356: blk.39.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 357: blk.39.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 358: blk.39.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 359: blk.39.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 360: blk.39.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 361: blk.39.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - tensor 362: blk.39.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
|
||||
"llama_model_loader: - kv 0: general.architecture str \n",
|
||||
"llama_model_loader: - kv 1: general.name str \n",
|
||||
"llama_model_loader: - kv 2: llama.context_length u32 \n",
|
||||
"llama_model_loader: - kv 3: llama.embedding_length u32 \n",
|
||||
"llama_model_loader: - kv 4: llama.block_count u32 \n",
|
||||
"llama_model_loader: - kv 5: llama.feed_forward_length u32 \n",
|
||||
"llama_model_loader: - kv 6: llama.rope.dimension_count u32 \n",
|
||||
"llama_model_loader: - kv 7: llama.attention.head_count u32 \n",
|
||||
"llama_model_loader: - kv 8: llama.attention.head_count_kv u32 \n",
|
||||
"llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 \n",
|
||||
"llama_model_loader: - kv 10: llama.rope.freq_base f32 \n",
|
||||
"llama_model_loader: - kv 11: general.file_type u32 \n",
|
||||
"llama_model_loader: - kv 12: tokenizer.ggml.model str \n",
|
||||
"llama_model_loader: - kv 13: tokenizer.ggml.tokens arr \n",
|
||||
"llama_model_loader: - kv 14: tokenizer.ggml.scores arr \n",
|
||||
"llama_model_loader: - kv 15: tokenizer.ggml.token_type arr \n",
|
||||
"llama_model_loader: - kv 16: general.quantization_version u32 \n",
|
||||
"llama_model_loader: - type f32: 81 tensors\n",
|
||||
"llama_model_loader: - type f16: 1 tensors\n",
|
||||
"llama_model_loader: - type q4_0: 1 tensors\n",
|
||||
"llama_model_loader: - type q4_K: 240 tensors\n",
|
||||
"llama_model_loader: - type q6_K: 40 tensors\n",
|
||||
"llm_load_print_meta: format = GGUF V1 (latest)\n",
|
||||
"llm_load_print_meta: arch = llama\n",
|
||||
"llm_load_print_meta: vocab type = SPM\n",
|
||||
"llm_load_print_meta: n_vocab = 32016\n",
|
||||
"llm_load_print_meta: n_merges = 0\n",
|
||||
"llm_load_print_meta: n_ctx_train = 16384\n",
|
||||
"llm_load_print_meta: n_ctx = 5000\n",
|
||||
"llm_load_print_meta: n_embd = 5120\n",
|
||||
"llm_load_print_meta: n_head = 40\n",
|
||||
"llm_load_print_meta: n_head_kv = 40\n",
|
||||
"llm_load_print_meta: n_layer = 40\n",
|
||||
"llm_load_print_meta: n_rot = 128\n",
|
||||
"llm_load_print_meta: n_gqa = 1\n",
|
||||
"llm_load_print_meta: f_norm_eps = 1.0e-05\n",
|
||||
"llm_load_print_meta: f_norm_rms_eps = 1.0e-05\n",
|
||||
"llm_load_print_meta: n_ff = 13824\n",
|
||||
"llm_load_print_meta: freq_base = 1000000.0\n",
|
||||
"llm_load_print_meta: freq_scale = 1\n",
|
||||
"llm_load_print_meta: model type = 13B\n",
|
||||
"llm_load_print_meta: model ftype = mostly Q4_K - Medium\n",
|
||||
"llm_load_print_meta: model size = 13.02 B\n",
|
||||
"llm_load_print_meta: general.name = LLaMA\n",
|
||||
"llm_load_print_meta: BOS token = 1 '<s>'\n",
|
||||
"llm_load_print_meta: EOS token = 2 '</s>'\n",
|
||||
"llm_load_print_meta: UNK token = 0 '<unk>'\n",
|
||||
"llm_load_print_meta: LF token = 13 '<0x0A>'\n",
|
||||
"llm_load_tensors: ggml ctx size = 0.11 MB\n",
|
||||
"llm_load_tensors: mem required = 7685.49 MB (+ 3906.25 MB per state)\n",
|
||||
".................................................................................................\n",
|
||||
"llama_new_context_with_model: kv self size = 3906.25 MB\n",
|
||||
"ggml_metal_init: allocating\n",
|
||||
"ggml_metal_init: loading '/Users/rlm/miniforge3/envs/llama2/lib/python3.9/site-packages/llama_cpp/ggml-metal.metal'\n",
|
||||
"ggml_metal_init: loaded kernel_add 0x12126dd00 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_add_row 0x12126d610 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul 0x12126f2a0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_row 0x12126f500 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_scale 0x12126f760 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_silu 0x12126fe40 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_relu 0x1212700a0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_gelu 0x121270300 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_soft_max 0x121270560 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_diag_mask_inf 0x1212707c0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_f16 0x121270a20 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q4_0 0x121270c80 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q4_1 0x121270ee0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q8_0 0x121271140 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q2_K 0x1212713a0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q3_K 0x121271600 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q4_K 0x121271860 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q5_K 0x121271ac0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q6_K 0x121271d20 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_rms_norm 0x121271f80 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_norm 0x1212721e0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_f16_f32 0x121272440 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q4_0_f32 0x1212726a0 | th_max = 896 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q4_1_f32 0x121272900 | th_max = 896 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q8_0_f32 0x121272b60 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q2_K_f32 0x121272dc0 | th_max = 640 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q3_K_f32 0x121273020 | th_max = 704 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q4_K_f32 0x121273280 | th_max = 576 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q5_K_f32 0x1212734e0 | th_max = 576 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q6_K_f32 0x121273740 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_f16_f32 0x1212739a0 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q4_0_f32 0x121273c00 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q8_0_f32 0x121273e60 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q4_1_f32 0x1212740c0 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q2_K_f32 0x121274320 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q3_K_f32 0x121274580 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q4_K_f32 0x1212747e0 | th_max = 768 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q5_K_f32 0x121274a40 | th_max = 704 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mm_q6_K_f32 0x121274ca0 | th_max = 704 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_rope 0x121274f00 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_alibi_f32 0x121275160 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_cpy_f32_f16 0x1212753c0 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_cpy_f32_f32 0x121275620 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: loaded kernel_cpy_f16_f16 0x121275880 | th_max = 1024 | th_width = 32\n",
|
||||
"ggml_metal_init: recommendedMaxWorkingSetSize = 21845.34 MB\n",
|
||||
"ggml_metal_init: hasUnifiedMemory = true\n",
|
||||
"ggml_metal_init: maxTransferRate = built-in GPU\n",
|
||||
"llama_new_context_with_model: compute buffer total size = 442.03 MB\n",
|
||||
"llama_new_context_with_model: max tensor size = 312.66 MB\n",
|
||||
"ggml_metal_add_buffer: allocated 'data ' buffer, size = 7686.00 MB, (20243.77 / 21845.34)\n",
|
||||
"ggml_metal_add_buffer: allocated 'eval ' buffer, size = 1.42 MB, (20245.19 / 21845.34)\n",
|
||||
"ggml_metal_add_buffer: allocated 'kv ' buffer, size = 3908.25 MB, (24153.44 / 21845.34), warning: current allocated size is greater than the recommended max working set size\n",
|
||||
"AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | \n",
|
||||
"ggml_metal_add_buffer: allocated 'alloc ' buffer, size = 440.64 MB, (24594.08 / 21845.34), warning: current allocated size is greater than the recommended max working set size\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
|
||||
"llm = LlamaCpp(\n",
|
||||
" model_path=\"/Users/rlm/Desktop/Code/llama/code-llama/codellama-13b-instruct.Q4_K_M.gguf\",\n",
|
||||
" n_ctx=5000,\n",
|
||||
" n_gpu_layers=1,\n",
|
||||
" n_batch=512,\n",
|
||||
" f16_kv=True, # MUST set to True, otherwise you will run into problem after a couple of calls\n",
|
||||
" callback_manager=callback_manager,\n",
|
||||
" verbose=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Llama.generate: prefix-match hit\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" \"How can I initialize a ReAct agent?\" To initialize a ReAct agent, you can use the `ReActAgent.from_llm_and_tools()` class method. This method takes two arguments: the LLM and a list of tools.\n",
|
||||
"Here is an example of how to initialize a ReAct agent with the OpenAI language model and the \"Search\" tool:\n",
|
||||
"from langchain.agents.mrkl.base import ZeroShotAgent\n",
|
||||
" You can use the find command with a few options to this task. Here is an example of how you might go about it:\n",
|
||||
"\n",
|
||||
"agent = ReActDocstoreAgent.from_llm_and_tools(OpenAIFunctionsAgent(), [Tool(\"Search\")]])\n",
|
||||
"find . -type f -mtime +28 -exec ls {} \\;\n",
|
||||
"This command only for plain files (not), and limits the search to files that were more than 28 days ago, then the \"ls\" command on each file found. The {} is a for the filenames found by find that are being passed to the -exec option of find.\n",
|
||||
"\n",
|
||||
" The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential."
|
||||
"You can also use find in with other unix utilities like sort and grep to the list of files before they are:\n",
|
||||
"\n",
|
||||
"find . -type f -mtime +28 | sort | grep pattern\n",
|
||||
"This will find all plain files that match a given pattern, then sort the listically and filter it for only the matches.\n",
|
||||
"\n",
|
||||
"Answer: `find` is pretty with its search. The should work as well:\n",
|
||||
"\n",
|
||||
"\\begin{code}\n",
|
||||
"ls -l $(find . -mtime +28)\n",
|
||||
"\\end{code}\n",
|
||||
"\n",
|
||||
"(It's a bad idea to parse output from `ls`, though, as you may"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"llama_print_timings: load time = 1074.43 ms\n",
|
||||
"llama_print_timings: sample time = 180.71 ms / 256 runs ( 0.71 ms per token, 1416.67 tokens per second)\n",
|
||||
"llama_print_timings: prompt eval time = 0.00 ms / 1 tokens ( 0.00 ms per token, inf tokens per second)\n",
|
||||
"llama_print_timings: eval time = 9593.04 ms / 256 runs ( 37.47 ms per token, 26.69 tokens per second)\n",
|
||||
"llama_print_timings: total time = 10139.91 ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' To initialize a ReAct agent, you can use the `ReActAgent.from_llm_and_tools()` class method. This method takes two arguments: the LLM and a list of tools.\\nHere is an example of how to initialize a ReAct agent with the OpenAI language model and the \"Search\" tool:\\nfrom langchain.agents.mrkl.base import ZeroShotAgent\\n\\nagent = ReActDocstoreAgent.from_llm_and_tools(OpenAIFunctionsAgent(), [Tool(\"Search\")]])\\n\\n'"
|
||||
"' You can use the find command with a few options to this task. Here is an example of how you might go about it:\\n\\nfind . -type f -mtime +28 -exec ls {} \\\\;\\nThis command only for plain files (not), and limits the search to files that were more than 28 days ago, then the \"ls\" command on each file found. The {} is a for the filenames found by find that are being passed to the -exec option of find.\\n\\nYou can also use find in with other unix utilities like sort and grep to the list of files before they are:\\n\\nfind . -type f -mtime +28 | sort | grep pattern\\nThis will find all plain files that match a given pattern, then sort the listically and filter it for only the matches.\\n\\nAnswer: `find` is pretty with its search. The should work as well:\\n\\n\\\\begin{code}\\nls -l $(find . -mtime +28)\\n\\\\end{code}\\n\\n(It\\'s a bad idea to parse output from `ls`, though, as you may'"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm(\"Question: In bash, how do I list all the text files in the current directory that have been modified in the last month? Answer:\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Llama.generate: prefix-match hit\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" You can use the `ReActAgent` class and pass it the desired tools as, for example, you would do like this to create an agent with the `Lookup` and `Search` tool:\n",
|
||||
"```python\n",
|
||||
"from langchain.agents.react import ReActAgent\n",
|
||||
"from langchain.tools.lookup import Lookup\n",
|
||||
"from langchain.tools.search import Search\n",
|
||||
"ReActAgent(Lookup(), Search())\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"llama_print_timings: load time = 1074.43 ms\n",
|
||||
"llama_print_timings: sample time = 65.46 ms / 94 runs ( 0.70 ms per token, 1435.95 tokens per second)\n",
|
||||
"llama_print_timings: prompt eval time = 15975.57 ms / 1408 tokens ( 11.35 ms per token, 88.13 tokens per second)\n",
|
||||
"llama_print_timings: eval time = 4772.57 ms / 93 runs ( 51.32 ms per token, 19.49 tokens per second)\n",
|
||||
"llama_print_timings: total time = 20959.57 ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'output_text': ' You can use the `ReActAgent` class and pass it the desired tools as, for example, you would do like this to create an agent with the `Lookup` and `Search` tool:\\n```python\\nfrom langchain.agents.react import ReActAgent\\nfrom langchain.tools.lookup import Lookup\\nfrom langchain.tools.search import Search\\nReActAgent(Lookup(), Search())\\n```'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||
"\n",
|
||||
"# Prompt\n",
|
||||
"template = \"\"\"Use the following pieces of context to answer the question at the end. \n",
|
||||
"If you don't know the answer, just say that you don't know, don't try to make up an answer. \n",
|
||||
"Use three sentences maximum and keep the answer as concise as possible. \n",
|
||||
"{context}\n",
|
||||
"Question: {question}\n",
|
||||
"Helpful Answer:\"\"\"\n",
|
||||
"QA_CHAIN_PROMPT = PromptTemplate(\n",
|
||||
" input_variables=[\"context\", \"question\"],\n",
|
||||
" template=template,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Docs\n",
|
||||
"question = \"How can I initialize a ReAct agent?\"\n",
|
||||
"result = qa_llama(question)\n",
|
||||
"result['answer']"
|
||||
"docs = retriever.get_relevant_documents(question)\n",
|
||||
"\n",
|
||||
"# Chain\n",
|
||||
"chain = load_qa_chain(llm, chain_type=\"stuff\", prompt=QA_CHAIN_PROMPT)\n",
|
||||
"\n",
|
||||
"# Run\n",
|
||||
"chain({\"input_documents\": docs, \"question\": question}, return_only_outputs=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can view the [LangSmith trace](https://smith.langchain.com/public/fd24c734-e365-4a09-b883-cdbc7dcfa582/r) to sanity check the result relative to what was retrieved."
|
||||
"Here's the trace [RAG](https://smith.langchain.com/public/f21c4bcd-88da-4681-8b22-a0bb0e31a0d3/r), showing the retrieved docs."
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -418,5 +1023,5 @@
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
"source": [
|
||||
"## Quickstart\n",
|
||||
"\n",
|
||||
"OpenAI funtions are one way to get started with extraction.\n",
|
||||
"OpenAI functions are one way to get started with extraction.\n",
|
||||
"\n",
|
||||
"Define a schema that specifies the properties we want to extract from the LLM output.\n",
|
||||
"\n",
|
||||
@@ -122,7 +122,7 @@
|
||||
"id": "6f7eb826",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Option 1: OpenAI funtions\n",
|
||||
"## Option 1: OpenAI functions\n",
|
||||
"\n",
|
||||
"### Looking under the hood\n",
|
||||
"\n",
|
||||
|
||||
718
docs/extras/use_cases/more/agents/agents.ipynb
Normal file
@@ -0,0 +1,718 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "842dd272",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Agents\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/more/agents/agents.ipynb)\n",
|
||||
"\n",
|
||||
"## Use case \n",
|
||||
"\n",
|
||||
"LLM-based agents are powerful general problem solvers.\n",
|
||||
"\n",
|
||||
"The [primary LLM agent components](https://lilianweng.github.io/posts/2023-06-23-agent/) include at least 3 things:\n",
|
||||
"\n",
|
||||
"* `Planning`: The ability to break down tasks into smaller sub-goals\n",
|
||||
"* `Memory`: The ability to retain and recall information\n",
|
||||
"* `Tools`: The ability to get information from external sources (e.g., APIs)\n",
|
||||
"\n",
|
||||
"Unlike LLMs simply connected to [APIs](/docs/use_cases/apis/apis), agents [can](https://www.youtube.com/watch?v=DWUdGhRrv2c):\n",
|
||||
"\n",
|
||||
"* Self-correct\n",
|
||||
"* Handle multi-hop tasks (several intermediate \"hops\" or steps to arrive at a conclusion)\n",
|
||||
"* Tackle long time horizon tasks (that require access to long-term memory)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Overview \n",
|
||||
"\n",
|
||||
"LangChain has [many agent types](/docs/modules/agents/agent_types/).\n",
|
||||
"\n",
|
||||
"Nearly all agents will use the following components:\n",
|
||||
" \n",
|
||||
"**Planning**\n",
|
||||
" \n",
|
||||
"* `Prompt`: Can given the LLM [personality](https://arxiv.org/pdf/2304.03442.pdf), context (e.g, via retrieval from memory), or strategies for learninng (e.g., [chain-of-thought](https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/#chain-of-thought-cot)).\n",
|
||||
"* `Agent` Responsible for deciding what step to take next using an LLM with the `Prompt`\n",
|
||||
"\n",
|
||||
"**Memory**\n",
|
||||
"\n",
|
||||
"* This can be short or long-term, allowing the agent to persist information.\n",
|
||||
"\n",
|
||||
"**Tools**\n",
|
||||
"\n",
|
||||
"* Tools are functions that an agent can call.\n",
|
||||
"\n",
|
||||
"But, there are some taxonomic differences:\n",
|
||||
"\n",
|
||||
"* `Action agents`: Designed to decide the sequence of actions (tool use) (e.g., OpenAI functions agents, ReAct agents).\n",
|
||||
"* `Simulation agents`: Designed for role-play often in simulated enviorment (e.g., Generative Agents, CAMEL).\n",
|
||||
"* `Autonomous agents`: Designed for indepdent execution towards long term goals (e.g., BabyAGI, Auto-GPT).\n",
|
||||
"\n",
|
||||
"This will focus on `Action agents`.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Quickstart "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a704c7a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install langchain openai google-search-results\n",
|
||||
"\n",
|
||||
"# Set env var OPENAI_API_KEY and SERPAPI_API_KEY or load from a .env file\n",
|
||||
"# import dotenv\n",
|
||||
"\n",
|
||||
"# dotenv.load_env()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "639d41ad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`Tools`\n",
|
||||
"\n",
|
||||
"LangChain has [many tools](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/load_tools.py) for Agents that we can load easily.\n",
|
||||
"\n",
|
||||
"Let's load search and a calcultor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c60001c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Tool\n",
|
||||
"from langchain.agents import load_tools\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "431ba30b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`Agent`\n",
|
||||
"\n",
|
||||
"The [`OPENAI_FUNCTIONS` agent](/docs/modules/agents/agent_types/openai_functions_agent) is a good action agent to start with.\n",
|
||||
"\n",
|
||||
"OpenAI models have been fine-tuned to recognize when function should be called."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d636395f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'As of 2023, the estimated population of Canada is approximately 39,858,480 people.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Prompt\n",
|
||||
"from langchain.agents import AgentExecutor\n",
|
||||
"from langchain.schema import SystemMessage\n",
|
||||
"from langchain.agents import OpenAIFunctionsAgent\n",
|
||||
"system_message = SystemMessage(content=\"You are a search assistant.\")\n",
|
||||
"prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)\n",
|
||||
"\n",
|
||||
"# Agent\n",
|
||||
"search_agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)\n",
|
||||
"agent_executor = AgentExecutor(agent=search_agent, tools=tools, verbose=False)\n",
|
||||
"\n",
|
||||
"# Run\n",
|
||||
"agent_executor.run(\"How many people live in canada as of 2023?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "27842380",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Great, we have created a simple search agent with a tool!\n",
|
||||
"\n",
|
||||
"Note that we use an agent executor, which is the runtime for an agent. \n",
|
||||
"\n",
|
||||
"This is what calls the agent and executes the actions it chooses. \n",
|
||||
"\n",
|
||||
"Pseudocode for this runtime is below:\n",
|
||||
"```\n",
|
||||
"next_action = agent.get_action(...)\n",
|
||||
"while next_action != AgentFinish:\n",
|
||||
" observation = run(next_action)\n",
|
||||
" next_action = agent.get_action(..., next_action, observation)\n",
|
||||
"return next_action\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"While this may seem simple, there are several complexities this runtime handles for you, including:\n",
|
||||
"\n",
|
||||
"* Handling cases where the agent selects a non-existent tool\n",
|
||||
"* Handling cases where the tool errors\n",
|
||||
"* Handling cases where the agent produces output that cannot be parsed into a tool invocation\n",
|
||||
"* Logging and observability at all levels (agent decisions, tool calls) either to stdout or LangSmith.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0b93c7d0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Memory \n",
|
||||
"\n",
|
||||
"### Short-term memory\n",
|
||||
"\n",
|
||||
"Of course, `memory` is needed to enable conversation / persistence of information.\n",
|
||||
"\n",
|
||||
"LangChain has many options for [short-term memory](/docs/modules/memory/types/), which are frequently used in [chat](/docs/modules/memory/adding_memory.html). \n",
|
||||
"\n",
|
||||
"They can be [employed with agents](/docs/modules/memory/agent_with_memory) too.\n",
|
||||
"\n",
|
||||
"`ConversationBufferMemory` is a popular choice for short-term memory.\n",
|
||||
"\n",
|
||||
"We set `MEMORY_KEY`, which can be referenced by the prompt later.\n",
|
||||
"\n",
|
||||
"Now, let's add memory to our agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "1d291015",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Memory \n",
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"MEMORY_KEY = \"chat_history\"\n",
|
||||
"memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)\n",
|
||||
"\n",
|
||||
"# Prompt w/ placeholder for memory\n",
|
||||
"from langchain.schema import SystemMessage\n",
|
||||
"from langchain.agents import OpenAIFunctionsAgent\n",
|
||||
"from langchain.prompts import MessagesPlaceholder\n",
|
||||
"system_message = SystemMessage(content=\"You are a search assistant tasked with using Serpapi to answer questions.\")\n",
|
||||
"prompt = OpenAIFunctionsAgent.create_prompt(\n",
|
||||
" system_message=system_message,\n",
|
||||
" extra_prompt_messages=[MessagesPlaceholder(variable_name=MEMORY_KEY)]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Agent\n",
|
||||
"search_agent_memory = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt, memory=memory)\n",
|
||||
"agent_executor_memory = AgentExecutor(agent=search_agent_memory, tools=tools, memory=memory, verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "b4b2249a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'As of August 2023, the estimated population of Canada is approximately 38,781,291 people.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_executor_memory.run(\"How many people live in Canada as of August, 2023?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "4d31b0cf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'As of August 2023, the largest province in Canada is Ontario, with a population of over 15 million people.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_executor_memory.run(\"What is the population of its largest provence as of August, 2023?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3606c32a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Looking at the [trace](https://smith.langchain.com/public/4425a131-ec90-4aaa-acd8-5b880c7452a3/r), we can what is happening:\n",
|
||||
"\n",
|
||||
"* The chat history is passed to the LLMs\n",
|
||||
"* This gives context to `its` in `What is the population of its largest provence as of August, 2023?`\n",
|
||||
"* The LLM generates a function call to the search tool\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"function_call:\n",
|
||||
" name: Search\n",
|
||||
" arguments: |-\n",
|
||||
" {\n",
|
||||
" \"query\": \"population of largest province in Canada as of August 2023\"\n",
|
||||
" }\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"* The search is executed\n",
|
||||
"* The results frum search are passed back to the LLM for synthesis into an answer\n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "384e37f8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Long-term memory \n",
|
||||
"\n",
|
||||
"Vectorstores are great options for long-term memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "1489746c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import faiss\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"from langchain.docstore import InMemoryDocstore\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"embedding_size = 1536\n",
|
||||
"embeddings_model = OpenAIEmbeddings()\n",
|
||||
"index = faiss.IndexFlatL2(embedding_size)\n",
|
||||
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9668ef5d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Going deeper \n",
|
||||
"\n",
|
||||
"* Explore projects using long-term memory, such as [autonomous agents](/docs/use_cases/autonomous_agents/autonomous_agents)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43fe2bb3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tools \n",
|
||||
"\n",
|
||||
"As mentioned above, LangChain has [many tools](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/load_tools.py) for Agents that we can load easily.\n",
|
||||
"\n",
|
||||
"We can also define [custom tools](/docs/modules/agents/tools/custom_tools). For example, here is a search tool.\n",
|
||||
"\n",
|
||||
"* The `Tool` dataclass wraps functions that accept a single string input and returns a string output.\n",
|
||||
"* `return_direct` determines whether to return the tool's output directly. \n",
|
||||
"* Setting this to `True` means that after the tool is called, the `AgentExecutor` will stop looping."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"id": "7357e496",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import Tool, tool\n",
|
||||
"from langchain.utilities import GoogleSearchAPIWrapper\n",
|
||||
"search = GoogleSearchAPIWrapper()\n",
|
||||
"search_tool = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events\",\n",
|
||||
" return_direct=True,\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c6ef5bfa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To make it easier to define custom tools, a `@tool` decorator is provided. \n",
|
||||
"\n",
|
||||
"This decorator can be used to quickly create a Tool from a simple function."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "b6308c69",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Tool\n",
|
||||
"@tool\n",
|
||||
"def get_word_length(word: str) -> int:\n",
|
||||
" \"\"\"Returns the length of a word.\"\"\"\n",
|
||||
" return len(word)\n",
|
||||
"word_length_tool = [get_word_length]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "83c104d7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Going deeper\n",
|
||||
"\n",
|
||||
"**Toolkits**\n",
|
||||
"\n",
|
||||
"* Toolkits are groups of tools needed to accomplish specific objectives.\n",
|
||||
"* [Here](/docs/integrations/toolkits/) are > 15 different agent toolkits (e.g., Gmail, Pandas, etc). \n",
|
||||
"\n",
|
||||
"Here is a simple way to think about agents vs the various chains covered in other docs:\n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5eefe4a0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Agents\n",
|
||||
"\n",
|
||||
"There's a number of [action agent types](docs/modules/agents/agent_types/) available in LangChain.\n",
|
||||
"\n",
|
||||
"* [ReAct](/docs/modules/agents/agent_types/react.html): This is the most general purpose action agent using the [ReAct framework](https://arxiv.org/pdf/2205.00445.pdf), which can work with [Docstores](/docs/modules/agents/agent_types/react_docstore.html) or [Multi-tool Inputs](/docs/modules/agents/agent_types/structured_chat.html).\n",
|
||||
"* [OpenAI functions](/docs/modules/agents/agent_types/openai_functions_agent.html): Designed to work with OpenAI function-calling models.\n",
|
||||
"* [Conversational](/docs/modules/agents/agent_types/chat_conversation_agent.html): This agent is designed to be used in conversational settings\n",
|
||||
"* [Self-ask with search](/docs/modules/agents/agent_types/self_ask_with_search.html): Designed to lookup factual answers to questions\n",
|
||||
"\n",
|
||||
"### OpenAI Functions agent\n",
|
||||
"\n",
|
||||
"As shown in Quickstart, let's continue with [`OpenAI functions` agent](/docs/modules/agents/agent_types/).\n",
|
||||
"\n",
|
||||
"This uses OpenAI models, which are fine-tuned to detect when a function should to be called.\n",
|
||||
"\n",
|
||||
"They will respond with the inputs that should be passed to the function.\n",
|
||||
"\n",
|
||||
"But, we can unpack it, first with a custom prompt:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "1c2deb4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Memory\n",
|
||||
"MEMORY_KEY = \"chat_history\"\n",
|
||||
"memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)\n",
|
||||
"\n",
|
||||
"# Prompt\n",
|
||||
"from langchain.schema import SystemMessage\n",
|
||||
"from langchain.agents import OpenAIFunctionsAgent\n",
|
||||
"system_message = SystemMessage(content=\"You are very powerful assistant, but bad at calculating lengths of words.\")\n",
|
||||
"prompt = OpenAIFunctionsAgent.create_prompt(\n",
|
||||
" system_message=system_message,\n",
|
||||
" extra_prompt_messages=[MessagesPlaceholder(variable_name=MEMORY_KEY)]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ee317a45",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define agent:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "460dab9b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Agent \n",
|
||||
"from langchain.agents import OpenAIFunctionsAgent\n",
|
||||
"agent = OpenAIFunctionsAgent(llm=llm, tools=word_length_tool, prompt=prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "184e6c23",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run agent:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "f4f27d37",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'There are 5 letters in the word \"educa\".'"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Run the executer, including short-term memory we created\n",
|
||||
"agent_executor = AgentExecutor(agent=agent, tools=word_length_tool, memory=memory, verbose=False)\n",
|
||||
"agent_executor.run(\"how many letters in the word educa?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4d9217e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### ReAct agent\n",
|
||||
"\n",
|
||||
"[ReAct](https://arxiv.org/abs/2210.03629) agents are another popular framework.\n",
|
||||
"\n",
|
||||
"There has been lots of work on [LLM reasoning](https://ai.googleblog.com/2022/05/language-models-perform-reasoning-via.html), such as chain-of-thought prompting.\n",
|
||||
"\n",
|
||||
"There also has been work on LLM action-taking to generate obervations, such as [Say-Can](https://say-can.github.io/).\n",
|
||||
"\n",
|
||||
"ReAct marries these two ideas:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"It uses a charecteristic `Thought`, `Action`, `Observation` [pattern in the output](https://lilianweng.github.io/posts/2023-06-23-agent/).\n",
|
||||
" \n",
|
||||
"We can use `initialize_agent` to create the ReAct agent from a list of available types [here](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/types.py):\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"* AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent\n",
|
||||
"* AgentType.REACT_DOCSTORE: ReActDocstoreAgent\n",
|
||||
"* AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent\n",
|
||||
"* AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent\n",
|
||||
"* AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent\n",
|
||||
"* AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent\n",
|
||||
"* AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION: StructuredChatAgent\n",
|
||||
"* AgentType.OPENAI_FUNCTIONS: OpenAIFunctionsAgent\n",
|
||||
"* AgentType.OPENAI_MULTI_FUNCTIONS: OpenAIMultiFunctionsAgent\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"id": "85f033d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import AgentType\n",
|
||||
"from langchain.agents import initialize_agent\n",
|
||||
"MEMORY_KEY = \"chat_history\"\n",
|
||||
"memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)\n",
|
||||
"react_agent = initialize_agent(search_tool, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7d05a26c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"react_agent(\"How many people live in Canada as of August, 2023?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b626dc5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"react_agent(\"What is the population of its largest provence as of August, 2023?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d4df0638",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"LangSmith can help us run diagnostics on the ReAct agent:\n",
|
||||
"\n",
|
||||
"The [ReAct agent](https://smith.langchain.com/public/3d8d0a15-d73f-44f3-9f81-037f7031c592/r) fails to pass chat history to LLM, gets wrong answer.\n",
|
||||
" \n",
|
||||
"The OAI functions agent does and [gets right answer](https://smith.langchain.com/public/4425a131-ec90-4aaa-acd8-5b880c7452a3/r), as shown above.\n",
|
||||
" \n",
|
||||
"Also the search tool result for [ReAct](https://smith.langchain.com/public/6473e608-fc9d-47c9-a8a4-2ef7f2801d82/r) is worse than [OAI](https://smith.langchain.com/public/4425a131-ec90-4aaa-acd8-5b880c7452a3/r/26b85fa9-e33a-4028-8650-1714f8b3db96).\n",
|
||||
"\n",
|
||||
"Collectivly, this tells us: carefully inspect Agent traces and tool outputs. \n",
|
||||
"\n",
|
||||
"As we saw with the [SQL use case](/docs/use_cases/sql), `ReAct agents` can be work very well for specific problems. \n",
|
||||
"\n",
|
||||
"But, as shown here, the result is degraded relative to what we see with the OpenAI agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5cde8f9a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Custom\n",
|
||||
"\n",
|
||||
"Let's peel it back even further to define our own action agent.\n",
|
||||
"\n",
|
||||
"We can [create a custom agent](/docs/modules/agents/how_to/custom_agent.html) to unpack the central pieces:\n",
|
||||
"\n",
|
||||
"* `Tools`: The tools the agent has available to use\n",
|
||||
"* `Agent`: decides which action to take"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"id": "3313f5cd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"The current population of Canada is 38,808,843 as of Tuesday, August 1, 2023, based on Worldometer elaboration of the latest United Nations data 1. Canada 2023\\xa0... Mar 22, 2023 ... Record-high population growth in the year 2022. Canada's population was estimated at 39,566,248 on January 1, 2023, after a record population\\xa0... Jun 19, 2023 ... As of June 16, 2023, there are now 40 million Canadians! This is a historic milestone for Canada and certainly cause for celebration. It is also\\xa0... Jun 28, 2023 ... Canada's population was estimated at 39,858,480 on April 1, 2023, an increase of 292,232 people (+0.7%) from January 1, 2023. The main driver of population growth is immigration, and to a lesser extent, natural growth. Demographics of Canada · Population pyramid of Canada in 2023. May 2, 2023 ... On January 1, 2023, Canada's population was estimated to be 39,566,248, following an unprecedented increase of 1,050,110 people between January\\xa0... Canada ranks 37th by population among countries of the world, comprising about 0.5% of the world's total, with over 40.0 million Canadians as of 2023. The current population of Canada in 2023 is 38,781,291, a 0.85% increase from 2022. The population of Canada in 2022 was 38,454,327, a 0.78% increase from 2021. Whether a given sub-nation is a province or a territory depends upon how its power and authority are derived. Provinces were given their power by the\\xa0... Jun 28, 2023 ... Index to the latest information from the Census of Population. ... 2023. Census in Brief: Multilingualism of Canadian households\\xa0...\""
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import List, Tuple, Any, Union\n",
|
||||
"from langchain.schema import AgentAction, AgentFinish\n",
|
||||
"from langchain.agents import Tool, AgentExecutor, BaseSingleActionAgent\n",
|
||||
"\n",
|
||||
"class FakeAgent(BaseSingleActionAgent):\n",
|
||||
" \"\"\"Fake Custom Agent.\"\"\"\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def input_keys(self):\n",
|
||||
" return [\"input\"]\n",
|
||||
"\n",
|
||||
" def plan(\n",
|
||||
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
|
||||
" ) -> Union[AgentAction, AgentFinish]:\n",
|
||||
" \"\"\"Given input, decided what to do.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" intermediate_steps: Steps the LLM has taken to date,\n",
|
||||
" along with observations\n",
|
||||
" **kwargs: User inputs.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" Action specifying what tool to use.\n",
|
||||
" \"\"\"\n",
|
||||
" return AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\")\n",
|
||||
"\n",
|
||||
" async def aplan(\n",
|
||||
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
|
||||
" ) -> Union[AgentAction, AgentFinish]:\n",
|
||||
" \"\"\"Given input, decided what to do.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" intermediate_steps: Steps the LLM has taken to date,\n",
|
||||
" along with observations\n",
|
||||
" **kwargs: User inputs.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" Action specifying what tool to use.\n",
|
||||
" \"\"\"\n",
|
||||
" return AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\")\n",
|
||||
" \n",
|
||||
"fake_agent = FakeAgent()\n",
|
||||
"fake_agent_executor = AgentExecutor.from_agent_and_tools(agent=fake_agent, \n",
|
||||
" tools=search_tool, \n",
|
||||
" verbose=False)\n",
|
||||
"\n",
|
||||
"fake_agent_executor.run(\"How many people live in canada as of 2023?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1335f0c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Runtime\n",
|
||||
"\n",
|
||||
"The `AgentExecutor` class is the main agent runtime supported by LangChain. \n",
|
||||
"\n",
|
||||
"However, there are other, more experimental runtimes for `autonomous_agents`:\n",
|
||||
" \n",
|
||||
"* Plan-and-execute Agent\n",
|
||||
"* Baby AGI\n",
|
||||
"* Auto GPT\n",
|
||||
"\n",
|
||||
"Explore more about:\n",
|
||||
"\n",
|
||||
"* [`Simulation agents`](/docs/modules/agents/agent_use_cases/agent_simulations): Designed for role-play often in simulated enviorment (e.g., Generative Agents, CAMEL).\n",
|
||||
"* [`Autonomous agents`](/docs/modules/agents/agent_use_cases/autonomous_agents): Designed for indepdent execution towards long term goals (e.g., BabyAGI, Auto-GPT).\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -249,7 +249,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Before Width: | Height: | Size: 769 KiB After Width: | Height: | Size: 769 KiB |
|
Before Width: | Height: | Size: 769 KiB After Width: | Height: | Size: 769 KiB |
|
Before Width: | Height: | Size: 369 KiB After Width: | Height: | Size: 369 KiB |
@@ -245,7 +245,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
@@ -25,8 +25,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install gpt4all\n",
|
||||
"! pip install chromadb"
|
||||
"pip install gpt4all chromadb"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -157,7 +156,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install llama-cpp-python"
|
||||
"pip install llama-cpp-python"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -736,7 +735,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,342 +0,0 @@
|
||||
---
|
||||
sidebar_position: -1
|
||||
---
|
||||
|
||||
# QA over Documents
|
||||
|
||||
## Use case
|
||||
Suppose you have some text documents (PDF, blog, Notion pages, etc.) and want to ask questions related to the contents of those documents. LLMs, given their proficiency in understanding text, are a great tool for this.
|
||||
|
||||
In this walkthrough we'll go over how to build a question-answering over documents application using LLMs. Two very related use cases which we cover elsewhere are:
|
||||
- [QA over structured data](/docs/use_cases/tabular) (e.g., SQL)
|
||||
- [QA over code](/docs/use_cases/code) (e.g., Python)
|
||||
|
||||

|
||||
|
||||
## Overview
|
||||
The pipeline for converting raw unstructured data into a QA chain looks like this:
|
||||
1. `Loading`: First we need to load our data. Unstructured data can be loaded from many sources. Use the [LangChain integration hub](https://integrations.langchain.com/) to browse the full set of loaders.
|
||||
Each loader returns data as a LangChain [`Document`](https://docs.langchain.com/docs/components/schema/document).
|
||||
2. `Splitting`: [Text splitters](/docs/modules/data_connection/document_transformers/) break `Documents` into splits of specified size
|
||||
3. `Storage`: Storage (e.g., often a [vectorstore](/docs/modules/data_connection/vectorstores/)) will house [and often embed](https://www.pinecone.io/learn/vector-embeddings/) the splits
|
||||
4. `Retrieval`: The app retrieves splits from storage (e.g., often [with similar embeddings](https://www.pinecone.io/learn/k-nearest-neighbor/) to the input question)
|
||||
5. `Generation`: An [LLM](/docs/modules/model_io/models/llms/) produces an answer using a prompt that includes the question and the retrieved data
|
||||
6. `Conversation` (Extension): Hold a multi-turn conversation by adding [Memory](/docs/modules/memory/) to your QA chain.
|
||||
|
||||

|
||||
|
||||
## Quickstart
|
||||
To give you a sneak preview, the above pipeline can all be wrapped in a single object: `VectorstoreIndexCreator`. Suppose we want a QA app over this [blog post](https://lilianweng.github.io/posts/2023-06-23-agent/). We can create this in a few lines of code:
|
||||
|
||||
First set environment variables and install packages:
|
||||
```bash
|
||||
pip install openai chromadb
|
||||
export OPENAI_API_KEY="..."
|
||||
```
|
||||
|
||||
Then run:
|
||||
```python
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain.indexes import VectorstoreIndexCreator
|
||||
|
||||
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
|
||||
index = VectorstoreIndexCreator().from_loaders([loader])
|
||||
```
|
||||
|
||||
And now ask your questions:
|
||||
```python
|
||||
index.query("What is Task Decomposition?")
|
||||
```
|
||||
|
||||
' Task decomposition is a technique used to break down complex tasks into smaller and simpler steps. It can be done using LLM with simple prompting, task-specific instructions, or human inputs. Tree of Thoughts (Yao et al. 2023) is an example of a task decomposition technique that explores multiple reasoning possibilities at each step and generates multiple thoughts per step, creating a tree structure.'
|
||||
|
||||
Ok, but what's going on under the hood, and how could we customize this for our specific use case? For that, let's take a look at how we can construct this pipeline piece by piece.
|
||||
|
||||
## Step 1. Load
|
||||
|
||||
Specify a `DocumentLoader` to load in your unstructured data as `Documents`. A `Document` is a piece of text (the `page_content`) and associated metadata.
|
||||
|
||||
```python
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
|
||||
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
|
||||
data = loader.load()
|
||||
```
|
||||
|
||||
### Go deeper
|
||||
- Browse the > 120 data loader integrations [here](https://integrations.langchain.com/) (one alternative loader is sketched below).
|
||||
- See further documentation on loaders [here](/docs/modules/data_connection/document_loaders/).
|
||||
|
||||
## Step 2. Split
|
||||
|
||||
Split the `Document` into chunks for embedding and vector storage.
|
||||
|
||||
```python
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 0)
|
||||
all_splits = text_splitter.split_documents(data)
|
||||
```
|
||||
|
||||
### Go deeper
|
||||
|
||||
- `DocumentSplitters` are just one type of the more generic `DocumentTransformers`, which can all be useful in this preprocessing step.
|
||||
- See further documentation on transformers [here](/docs/modules/data_connection/document_transformers/).
|
||||
- `Context-aware splitters` keep the location ("context") of each split in the original `Document` (a sketch follows this list):
|
||||
- [Markdown files](/docs/use_cases/question_answering/document-context-aware-QA)
|
||||
- [Code (py or js)](/docs/modules/data_connection/document_loaders/integrations/source_code)
|
||||
- [Documents](/docs/modules/data_connection/document_loaders/integrations/grobid)
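As a sketch of a context-aware splitter for Markdown, `MarkdownHeaderTextSplitter` records which headers each chunk came from in its metadata. The header mapping and the sample string below are just illustrative:

```python
from langchain.text_splitter import MarkdownHeaderTextSplitter

headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")]
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

md_header_splits = markdown_splitter.split_text("# Intro\n\n## Setup\n\nSome setup text.")
md_header_splits[0].metadata  # e.g., {'Header 1': 'Intro', 'Header 2': 'Setup'}
```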
|
||||
|
||||
## Step 3. Store
|
||||
|
||||
To be able to look up our document splits, we first need to store them where we can later look them up.
|
||||
The most common way to do this is to embed the contents of each document then store the embedding and document in a vector store, with the embedding being used to index the document.
|
||||
|
||||
```python
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())
|
||||
```
|
||||
|
||||
### Go deeper
|
||||
- Browse the > 40 vectorstores integrations [here](https://integrations.langchain.com/).
|
||||
- See further documentation on vectorstores [here](/docs/modules/data_connection/vectorstores/).
|
||||
- Browse the > 30 text embedding integrations [here](https://integrations.langchain.com/).
|
||||
- See further documentation on embedding models [here](/docs/modules/data_connection/text_embedding/).
|
||||
|
||||
Here are Steps 1-3:
|
||||
|
||||

|
||||
|
||||
## Step 4. Retrieve
|
||||
|
||||
Retrieve relevant splits for any question using [similarity search](https://www.pinecone.io/learn/what-is-similarity-search/).
|
||||
|
||||
```python
|
||||
question = "What are the approaches to Task Decomposition?"
|
||||
docs = vectorstore.similarity_search(question)
|
||||
len(docs)
|
||||
```
|
||||
|
||||
4
|
||||
|
||||
### Go deeper
|
||||
|
||||
Vectorstores are commonly used for retrieval, but they are not the only option. For example, SVMs (see thread [here](https://twitter.com/karpathy/status/1647025230546886658?s=20)) can also be used.
|
||||
|
||||
LangChain [has many retrievers](/docs/modules/data_connection/retrievers/) including, but not limited to, vectorstores. All retrievers implement a common method `get_relevant_documents()` (and its asynchronous variant `aget_relevant_documents()`).
|
||||
|
||||
```python
|
||||
from langchain.retrievers import SVMRetriever
|
||||
|
||||
svm_retriever = SVMRetriever.from_documents(all_splits,OpenAIEmbeddings())
|
||||
docs_svm=svm_retriever.get_relevant_documents(question)
|
||||
len(docs_svm)
|
||||
```
|
||||
|
||||
4
|
||||
|
||||
Some common ways to improve on vector similarity search include:
|
||||
- `MultiQueryRetriever` [generates variants of the input question](/docs/modules/data_connection/retrievers/MultiQueryRetriever) to improve retrieval.
|
||||
- `Max marginal relevance` selects for [relevance and diversity](https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf) among the retrieved documents (sketched after the `MultiQueryRetriever` example below).
|
||||
- Documents can be filtered during retrieval using [`metadata` filters](/docs/use_cases/question_answering/how_to/document-context-aware-QA).
|
||||
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
|
||||
|
||||
retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectorstore.as_retriever(),
|
||||
llm=ChatOpenAI(temperature=0))
|
||||
unique_docs = retriever_from_llm.get_relevant_documents(query=question)
|
||||
len(unique_docs)
|
||||
```
|
||||
|
||||
INFO:langchain.retrievers.multi_query:Generated queries: ['1. How can Task Decomposition be approached?', '2. What are the different methods for Task Decomposition?', '3. What are the various approaches to decomposing tasks?']
|
||||
5
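Max marginal relevance can be tried by asking the vectorstore for an MMR-based retriever. A minimal sketch, reusing the `vectorstore` and `question` defined above:

```python
# MMR trades off similarity to the query against diversity among the results
retriever_mmr = vectorstore.as_retriever(search_type="mmr")
docs_mmr = retriever_mmr.get_relevant_documents(question)
len(docs_mmr)
```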
|
||||
|
||||
## Step 5. Generate
|
||||
|
||||
Distill the retrieved documents into an answer using an LLM/Chat model (e.g., `gpt-3.5-turbo`) with the `RetrievalQA` chain.
|
||||
|
||||
```python
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
|
||||
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever())
|
||||
qa_chain({"query": question})
|
||||
```
|
||||
|
||||
{
|
||||
'query': 'What are the approaches to Task Decomposition?',
|
||||
'result': 'The approaches to task decomposition include:\n\n1. Simple prompting: This approach involves using simple prompts or questions to guide the agent in breaking down a task into smaller subgoals. For example, the agent can be prompted with "Steps for XYZ" and asked to list the subgoals for achieving XYZ.\n\n2. Task-specific instructions: In this approach, task-specific instructions are provided to the agent to guide the decomposition process. For example, if the task is to write a novel, the agent can be instructed to "Write a story outline" as a subgoal.\n\n3. Human inputs: This approach involves incorporating human inputs in the task decomposition process. Humans can provide guidance, feedback, and suggestions to help the agent break down complex tasks into manageable subgoals.\n\nThese approaches aim to enable efficient handling of complex tasks by breaking them down into smaller, more manageable parts.'
|
||||
}
|
||||
|
||||
Note, you can pass in an `LLM` or a `ChatModel` (like we did here) to the `RetrievalQA` chain.
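For example, a plain completion-style LLM can be swapped in without any other change to the chain. A sketch using the default OpenAI LLM:

```python
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA

llm_completion = OpenAI(temperature=0)  # a non-chat LLM
qa_chain_llm = RetrievalQA.from_chain_type(llm_completion, retriever=vectorstore.as_retriever())
qa_chain_llm({"query": question})
```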
|
||||
|
||||
### Go deeper
|
||||
|
||||
#### Choosing LLMs
|
||||
- Browse the > 55 LLM and chat model integrations [here](https://integrations.langchain.com/).
|
||||
- See further documentation on LLMs and chat models [here](/docs/modules/model_io/models/).
|
||||
- Use local LLMS: The popularity of [PrivateGPT](https://github.com/imartinez/privateGPT) and [GPT4All](https://github.com/nomic-ai/gpt4all) underscore the importance of running LLMs locally.
|
||||
Using `GPT4All` is as simple as [downloading the binary](/docs/integrations/llms/gpt4all) and then:
|
||||
|
||||
from langchain.llms import GPT4All
|
||||
from langchain.chains import RetrievalQA
|
||||
|
||||
llm = GPT4All(model="/Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin",max_tokens=2048)
|
||||
qa_chain = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever())
|
||||
|
||||
#### Customizing the prompt
|
||||
|
||||
The prompt in the `RetrievalQA` chain can be easily customized.
|
||||
|
||||
```python
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
template = """Use the following pieces of context to answer the question at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
Use three sentences maximum and keep the answer as concise as possible.
|
||||
Always say "thanks for asking!" at the end of the answer.
|
||||
{context}
|
||||
Question: {question}
|
||||
Helpful Answer:"""
|
||||
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
|
||||
|
||||
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
|
||||
qa_chain = RetrievalQA.from_chain_type(
|
||||
llm,
|
||||
retriever=vectorstore.as_retriever(),
|
||||
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
|
||||
)
|
||||
result = qa_chain({"query": question})
|
||||
result["result"]
|
||||
```
|
||||
|
||||
'The approaches to Task Decomposition are (1) using simple prompting by LLM, (2) using task-specific instructions, and (3) with human inputs. Thanks for asking!'
|
||||
|
||||
|
||||
#### Return source documents
|
||||
|
||||
The full set of retrieved documents used for answer distillation can be returned using `return_source_documents=True`.
|
||||
|
||||
```python
|
||||
from langchain.chains import RetrievalQA
|
||||
|
||||
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),
|
||||
return_source_documents=True)
|
||||
result = qa_chain({"query": question})
|
||||
print(len(result['source_documents']))
|
||||
result['source_documents'][0]
|
||||
```
|
||||
|
||||
4
|
||||
Document(page_content='Task decomposition can be done (1) by LLM with simple prompting like "Steps for XYZ.\\n1.", "What are the subgoals for achieving XYZ?", (2) by using task-specific instructions; e.g. "Write a story outline." for writing a novel, or (3) with human inputs.', metadata={'source': 'https://lilianweng.github.io/posts/2023-06-23-agent/', 'title': "LLM Powered Autonomous Agents | Lil'Log", 'description': 'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyAGI, serve as inspiring examples. The potentiality of LLM extends beyond generating well-written copies, stories, essays and programs; it can be framed as a powerful general problem solver.\nAgent System Overview In a LLM-powered autonomous agent system, LLM functions as the agent’s brain, complemented by several key components:', 'language': 'en'})
|
||||
|
||||
|
||||
|
||||
#### Return citations
|
||||
|
||||
Answer citations can be returned using `RetrievalQAWithSourcesChain`.
|
||||
|
||||
|
||||
```python
|
||||
from langchain.chains import RetrievalQAWithSourcesChain
|
||||
|
||||
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm,retriever=vectorstore.as_retriever())
|
||||
|
||||
result = qa_chain({"question": question})
|
||||
result
|
||||
```
|
||||
|
||||
{
|
||||
'question': 'What are the approaches to Task Decomposition?',
|
||||
'answer': 'The approaches to Task Decomposition include (1) using LLM with simple prompting, (2) using task-specific instructions, and (3) incorporating human inputs.\n',
|
||||
'sources': 'https://lilianweng.github.io/posts/2023-06-23-agent/'
|
||||
}
|
||||
|
||||
#### Customizing retrieved document processing
|
||||
|
||||
Retrieved documents can be fed to an LLM for answer distillation in a few different ways.
|
||||
|
||||
`stuff`, `refine`, `map-reduce`, and `map-rerank` chains for passing documents to an LLM prompt are well summarized [here](/docs/modules/chains/document/).
|
||||
|
||||
`stuff` is commonly used because it simply "stuffs" all retrieved documents into the prompt.
|
||||
|
||||
The [load_qa_chain](/docs/use_cases/question_answering/how_to/question_answering.html) is an easy way to pass documents to an LLM using these various approaches (e.g., see `chain_type`).
|
||||
|
||||
|
||||
```python
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
|
||||
chain = load_qa_chain(llm, chain_type="stuff")
|
||||
chain({"input_documents": unique_docs, "question": question},return_only_outputs=True)
|
||||
```
|
||||
|
||||
{'output_text': 'The approaches to task decomposition include (1) using simple prompting to break down tasks into subgoals, (2) providing task-specific instructions to guide the decomposition process, and (3) incorporating human inputs for task decomposition.'}
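For example, switching to the `map_reduce` approach only requires changing `chain_type`; each document is processed separately before the partial answers are combined. A sketch, reusing `unique_docs` and `question` from above:

```python
map_reduce_chain = load_qa_chain(llm, chain_type="map_reduce")
map_reduce_chain({"input_documents": unique_docs, "question": question}, return_only_outputs=True)
```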
|
||||
|
||||
We can also pass the `chain_type` to `RetrievalQA`.
|
||||
|
||||
|
||||
```python
|
||||
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),
|
||||
chain_type="stuff")
|
||||
result = qa_chain({"query": question})
|
||||
```
|
||||
|
||||
In summary, the user can choose the desired level of abstraction for QA:
|
||||
|
||||

|
||||
|
||||
## Step 6. Converse (Extension)
|
||||
|
||||
To hold a conversation, a chain needs to be able to refer to past interactions. Chain `Memory` allows us to do this. To keep chat history, we can specify a Memory buffer to track the conversation inputs / outputs.
|
||||
|
||||
```python
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
||||
```
|
||||
|
||||
The `ConversationalRetrievalChain` uses the chat history stored in the `Memory` buffer.
|
||||
|
||||
```python
|
||||
from langchain.chains import ConversationalRetrievalChain
|
||||
|
||||
retriever = vectorstore.as_retriever()
|
||||
chat = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)
|
||||
```
|
||||
|
||||
```python
|
||||
result = chat({"question": "What are some of the main ideas in self-reflection?"})
|
||||
result['answer']
|
||||
```
|
||||
|
||||
"Some of the main ideas in self-reflection include:\n1. Iterative improvement: Self-reflection allows autonomous agents to improve by refining past action decisions and correcting mistakes.\n2. Trial and error: Self-reflection is crucial in real-world tasks where trial and error are inevitable.\n3. Two-shot examples: Self-reflection is created by showing pairs of failed trajectories and ideal reflections for guiding future changes in the plan.\n4. Working memory: Reflections are added to the agent's working memory, up to three, to be used as context for querying.\n5. Performance evaluation: Self-reflection involves continuously reviewing and analyzing actions, self-criticizing behavior, and reflecting on past decisions and strategies to refine approaches.\n6. Efficiency: Self-reflection encourages being smart and efficient, aiming to complete tasks in the least number of steps."
|
||||
|
||||
The Memory buffer has the context needed to resolve `"it"` ("self-reflection") in the question below.
|
||||
|
||||
```python
|
||||
result = chat({"question": "How does the Reflexion paper handle it?"})
|
||||
result['answer']
|
||||
```
|
||||
|
||||
"The Reflexion paper handles self-reflection by showing two-shot examples to the Learning Language Model (LLM). Each example consists of a failed trajectory and an ideal reflection that guides future changes in the agent's plan. These reflections are then added to the agent's working memory, up to a maximum of three, to be used as context for querying the LLM. This allows the agent to iteratively improve its reasoning skills by refining past action decisions and correcting previous mistakes."
|
||||
|
||||
### Go deeper
|
||||
|
||||
The [documentation](/docs/use_cases/question_answering/how_to/chat_vector_db) on `ConversationalRetrievalChain` offers a few extensions, such as streaming and source documents.
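For instance, source documents can also be returned from the conversational chain. A hedged sketch, reusing the `llm` and `retriever` from above; note that when sources are returned, the memory needs an explicit `output_key`:

```python
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# output_key tells the memory which chain output to store alongside the chat history
memory_with_sources = ConversationBufferMemory(
    memory_key="chat_history", return_messages=True, output_key="answer"
)
chat_with_sources = ConversationalRetrievalChain.from_llm(
    llm, retriever=retriever, memory=memory_with_sources, return_source_documents=True
)
result = chat_with_sources({"question": "What are some of the main ideas in self-reflection?"})
result["source_documents"][0]
```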
|
||||
|
||||
|
||||
## Further reading
|
||||
- Check out the [How to](/docs/use_cases/question_answering/how_to/) section for all the variations of chains that can be used for QA over docs in different settings.
|
||||
- Check out the [Integrations-specific](/docs/use_cases/question_answering/integrations/) section for chains that use specific integrations.
|
||||
@@ -0,0 +1,686 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5151afed",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Question Answering\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/question_answering/qa.ipynb)\n",
|
||||
"\n",
|
||||
"## Use case\n",
|
||||
"Suppose you have some text documents (PDF, blog, Notion pages, etc.) and want to ask questions related to the contents of those documents. LLMs, given their proficiency in understanding text, are a great tool for this.\n",
|
||||
"\n",
|
||||
"In this walkthrough we'll go over how to build a question-answering over documents application using LLMs. Two very related use cases which we cover elsewhere are:\n",
|
||||
"- [QA over structured data](/docs/use_cases/sql) (e.g., SQL)\n",
|
||||
"- [QA over code](/docs/use_cases/code) (e.g., Python)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"The pipeline for converting raw unstructured data into a QA chain looks like this:\n",
|
||||
"1. `Loading`: First we need to load our data. Unstructured data can be loaded from many sources. Use the [LangChain integration hub](https://integrations.langchain.com/) to browse the full set of loaders.\n",
|
||||
"Each loader returns data as a LangChain [`Document`](/docs/components/schema/document).\n",
|
||||
"2. `Splitting`: [Text splitters](/docs/modules/data_connection/document_transformers/) break `Documents` into splits of specified size\n",
|
||||
"3. `Storage`: Storage (e.g., often a [vectorstore](/docs/modules/data_connection/vectorstores/)) will house [and often embed](https://www.pinecone.io/learn/vector-embeddings/) the splits\n",
|
||||
"4. `Retrieval`: The app retrieves splits from storage (e.g., often [with similar embeddings](https://www.pinecone.io/learn/k-nearest-neighbor/) to the input question)\n",
|
||||
"5. `Generation`: An [LLM](/docs/modules/model_io/models/llms/) produces an answer using a prompt that includes the question and the retrieved data\n",
|
||||
"6. `Conversation` (Extension): Hold a multi-turn conversation by adding [Memory](/docs/modules/memory/) to your QA chain.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Quickstart\n",
|
||||
"\n",
|
||||
"To give you a sneak preview, the above pipeline can be all be wrapped in a single object: `VectorstoreIndexCreator`. Suppose we want a QA app over this [blog post](https://lilianweng.github.io/posts/2023-06-23-agent/). We can create this in a few lines of code. First set environment variables and install packages:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e14b744b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pip install openai chromadb\n",
|
||||
"\n",
|
||||
"# Set env var OPENAI_API_KEY or load from a .env file\n",
|
||||
"# import dotenv\n",
|
||||
"\n",
|
||||
"# dotenv.load_env()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "046cefc0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import WebBaseLoader\n",
|
||||
"from langchain.indexes import VectorstoreIndexCreator\n",
|
||||
"\n",
|
||||
"loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
|
||||
"index = VectorstoreIndexCreator().from_loaders([loader])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f4bf8740",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Task decomposition is a technique used to break down complex tasks into smaller and simpler steps. It can be done using LLM with simple prompting, task-specific instructions, or with human inputs. Tree of Thoughts (Yao et al. 2023) is an extension of Chain of Thought (Wei et al. 2022) which explores multiple reasoning possibilities at each step.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"index.query(\"What is Task Decomposition?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8224aad6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ok, but what's going on under the hood, and how could we customize this for our specific use case? For that, let's take a look at how we can construct this pipeline piece by piece."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ba5daed6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 1. Load\n",
|
||||
"\n",
|
||||
"Specify a `DocumentLoader` to load in your unstructured data as `Documents`. A `Document` is a piece of text (the `page_content`) and associated metadata."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "cf4d5c72",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import WebBaseLoader\n",
|
||||
"\n",
|
||||
"loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
|
||||
"data = loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fd2cc9a7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Go deeper\n",
|
||||
"- Browse the > 120 data loader integrations [here](https://integrations.langchain.com/).\n",
|
||||
"- See further documentation on loaders [here](/docs/modules/data_connection/document_loaders/).\n",
|
||||
"\n",
|
||||
"## Step 2. Split\n",
|
||||
"\n",
|
||||
"Split the `Document` into chunks for embedding and vector storage."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "4b11c01d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"\n",
|
||||
"text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 0)\n",
|
||||
"all_splits = text_splitter.split_documents(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0a33bd4d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Go deeper\n",
|
||||
"\n",
|
||||
"- `DocumentSplitters` are just one type of the more generic `DocumentTransformers`, which can all be useful in this preprocessing step.\n",
|
||||
"- See further documentation on transformers [here](/docs/modules/data_connection/document_transformers/).\n",
|
||||
"- `Context-aware splitters` keep the location (\"context\") of each split in the original `Document`:\n",
|
||||
" - [Markdown files](/docs/use_cases/question_answering/how_to/document-context-aware-QA)\n",
|
||||
" - [Code (py or js)](docs/integrations/document_loaders/source_code)\n",
|
||||
" - [Documents](/docs/integrations/document_loaders/grobid)\n",
|
||||
"\n",
|
||||
"## Step 3. Store\n",
|
||||
"\n",
|
||||
"To be able to look up our document splits, we first need to store them where we can later look them up.\n",
|
||||
"The most common way to do this is to embed the contents of each document then store the embedding and document in a vector store, with the embedding being used to index the document."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "e9c302c8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"\n",
|
||||
"vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dc6f22b0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Go deeper\n",
|
||||
"- Browse the > 40 vectorstores integrations [here](https://integrations.langchain.com/).\n",
|
||||
"- See further documentation on vectorstores [here](/docs/modules/data_connection/vectorstores/).\n",
|
||||
"- Browse the > 30 text embedding integrations [here](https://integrations.langchain.com/).\n",
|
||||
"- See further documentation on embedding models [here](/docs/modules/data_connection/text_embedding/).\n",
|
||||
"\n",
|
||||
" Here are Steps 1-3:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Step 4. Retrieve\n",
|
||||
"\n",
|
||||
"Retrieve relevant splits for any question using [similarity search](https://www.pinecone.io/learn/what-is-similarity-search/)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e2c26b7d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"4"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"question = \"What are the approaches to Task Decomposition?\"\n",
|
||||
"docs = vectorstore.similarity_search(question)\n",
|
||||
"len(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5d5a113b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Go deeper\n",
|
||||
"\n",
|
||||
"Vectorstores are commonly used for retrieval, but they are not the only option. For example, SVMs (see thread [here](https://twitter.com/karpathy/status/1647025230546886658?s=20)) can also be used.\n",
|
||||
"\n",
|
||||
"LangChain [has many retrievers](/docs/modules/data_connection/retrievers/) including, but not limited to, vectorstores. All retrievers implement a common method `get_relevant_documents()` (and its asynchronous variant `aget_relevant_documents()`)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "c901eaee",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"4"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.retrievers import SVMRetriever\n",
|
||||
"\n",
|
||||
"svm_retriever = SVMRetriever.from_documents(all_splits,OpenAIEmbeddings())\n",
|
||||
"docs_svm=svm_retriever.get_relevant_documents(question)\n",
|
||||
"len(docs_svm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "69de3d54",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Some common ways to improve on vector similarity search include:\n",
|
||||
"- `MultiQueryRetriever` [generates variants of the input question](/docs/modules/data_connection/retrievers/MultiQueryRetriever) to improve retrieval.\n",
|
||||
"- `Max marginal relevance` selects for [relevance and diversity](https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf) among the retrieved documents.\n",
|
||||
"- Documents can be filtered during retrieval using [`metadata` filters](/docs/use_cases/question_answering/how_to/document-context-aware-QA)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "c690f01a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:langchain.retrievers.multi_query:Generated queries: ['1. How can Task Decomposition be approached?', '2. What are the different methods for Task Decomposition?', '3. What are the various approaches to decomposing tasks?']\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"4"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.retrievers.multi_query import MultiQueryRetriever\n",
|
||||
"\n",
|
||||
"logging.basicConfig()\n",
|
||||
"logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)\n",
|
||||
"\n",
|
||||
"retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectorstore.as_retriever(),\n",
|
||||
" llm=ChatOpenAI(temperature=0))\n",
|
||||
"unique_docs = retriever_from_llm.get_relevant_documents(query=question)\n",
|
||||
"len(unique_docs)"
|
||||
]
|
||||
},
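{
"cell_type": "markdown",
"id": "mmr-go-deeper-sketch",
"metadata": {},
"source": [
"As a minimal sketch of the max marginal relevance option mentioned above (reusing the `vectorstore` built earlier; the `search_kwargs` values are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mmr-go-deeper-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical example: retrieve with max marginal relevance instead of plain similarity search\n",
"mmr_retriever = vectorstore.as_retriever(search_type=\"mmr\", search_kwargs={\"k\": 4})\n",
"docs_mmr = mmr_retriever.get_relevant_documents(question)\n",
"len(docs_mmr)"
]
},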
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "415d6824",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 5. Generate\n",
|
||||
"\n",
|
||||
"Distill the retrieved documents into an answer using an LLM/Chat model (e.g., `gpt-3.5-turbo`) with `RetrievalQA` chain.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "99fa1aec",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'query': 'What are the approaches to Task Decomposition?',\n",
|
||||
" 'result': 'There are three approaches to task decomposition:\\n\\n1. Using Language Model with simple prompting: This approach involves using a Language Model (LLM) with simple prompts like \"Steps for XYZ\" or \"What are the subgoals for achieving XYZ?\" to guide the task decomposition process.\\n\\n2. Using task-specific instructions: In this approach, task-specific instructions are provided to guide the task decomposition. For example, for the task of writing a novel, an instruction like \"Write a story outline\" can be given to help decompose the task into smaller subtasks.\\n\\n3. Human inputs: Task decomposition can also be done with the help of human inputs. This involves getting input and guidance from humans to break down a complex task into smaller, more manageable subtasks.'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import RetrievalQA\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)\n",
|
||||
"qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever())\n",
|
||||
"qa_chain({\"query\": question})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f7d52c84",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note, you can pass in an `LLM` or a `ChatModel` (like we did here) to the `RetrievalQA` chain.\n",
|
||||
"\n",
|
||||
"### Go deeper\n",
|
||||
"\n",
|
||||
"#### Choosing LLMs\n",
|
||||
"- Browse the > 55 LLM and chat model integrations [here](https://integrations.langchain.com/).\n",
|
||||
"- See further documentation on LLMs and chat models [here](/docs/modules/model_io/models/).\n",
|
||||
"- Use local LLMS: The popularity of [PrivateGPT](https://github.com/imartinez/privateGPT) and [GPT4All](https://github.com/nomic-ai/gpt4all) underscore the importance of running LLMs locally.\n",
|
||||
"Using `GPT4All` is as simple as [downloading the binary]((/docs/integrations/llms/gpt4all)) and then:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "02d6c9dc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found model file at /Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"objc[61331]: Class GGMLMetalClass is implemented in both /Users/rlm/miniforge3/envs/llama/lib/python3.9/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libreplit-mainline-metal.dylib (0x2e3384208) and /Users/rlm/miniforge3/envs/llama/lib/python3.9/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libllamamodel-mainline-metal.dylib (0x2e37b0208). One of the two will be used. Which one is undefined.\n",
|
||||
"llama.cpp: using Metal\n",
|
||||
"llama.cpp: loading model from /Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin\n",
|
||||
"llama_model_load_internal: format = ggjt v3 (latest)\n",
|
||||
"llama_model_load_internal: n_vocab = 32001\n",
|
||||
"llama_model_load_internal: n_ctx = 2048\n",
|
||||
"llama_model_load_internal: n_embd = 5120\n",
|
||||
"llama_model_load_internal: n_mult = 256\n",
|
||||
"llama_model_load_internal: n_head = 40\n",
|
||||
"llama_model_load_internal: n_layer = 40\n",
|
||||
"llama_model_load_internal: n_rot = 128\n",
|
||||
"llama_model_load_internal: ftype = 2 (mostly Q4_0)\n",
|
||||
"llama_model_load_internal: n_ff = 13824\n",
|
||||
"llama_model_load_internal: n_parts = 1\n",
|
||||
"llama_model_load_internal: model size = 13B\n",
|
||||
"llama_model_load_internal: ggml ctx size = 0.09 MB\n",
|
||||
"llama_model_load_internal: mem required = 9031.71 MB (+ 1608.00 MB per state)\n",
|
||||
"llama_new_context_with_model: kv self size = 1600.00 MB\n",
|
||||
"ggml_metal_init: allocating\n",
|
||||
"ggml_metal_init: using MPS\n",
|
||||
"ggml_metal_init: loading '/Users/rlm/miniforge3/envs/llama/lib/python3.9/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/ggml-metal.metal'\n",
|
||||
"ggml_metal_init: loaded kernel_add 0x2bbbbc2f0\n",
|
||||
"ggml_metal_init: loaded kernel_mul 0x2bbbba840\n",
|
||||
"ggml_metal_init: loaded kernel_mul_row 0x2bb917dd0\n",
|
||||
"ggml_metal_init: loaded kernel_scale 0x2bb918150\n",
|
||||
"ggml_metal_init: loaded kernel_silu 0x2bb9184d0\n",
|
||||
"ggml_metal_init: loaded kernel_relu 0x2bb918850\n",
|
||||
"ggml_metal_init: loaded kernel_gelu 0x2bbbc3f10\n",
|
||||
"ggml_metal_init: loaded kernel_soft_max 0x2bbbc5840\n",
|
||||
"ggml_metal_init: loaded kernel_diag_mask_inf 0x2bbbc4c70\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_f16 0x2bbbc5fc0\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q4_0 0x2bbbc6720\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q4_1 0x2bb918c10\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q2_k 0x2bbbc51b0\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q3_k 0x2bbbc7630\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q4_k 0x2d4394e30\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q5_k 0x2bbbc7890\n",
|
||||
"ggml_metal_init: loaded kernel_get_rows_q6_k 0x2d4395210\n",
|
||||
"ggml_metal_init: loaded kernel_rms_norm 0x2bbbc8740\n",
|
||||
"ggml_metal_init: loaded kernel_norm 0x2bbbc8b30\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_f16_f32 0x2d4395470\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q4_0_f32 0x2d4395a70\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q4_1_f32 0x1242b1a00\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q2_k_f32 0x29f17d1c0\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q3_k_f32 0x2d4396050\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q4_k_f32 0x2bbbc98a0\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q5_k_f32 0x2bbbca4a0\n",
|
||||
"ggml_metal_init: loaded kernel_mul_mat_q6_k_f32 0x2bbbcae90\n",
|
||||
"ggml_metal_init: loaded kernel_rope 0x2bbbca700\n",
|
||||
"ggml_metal_init: loaded kernel_alibi_f32 0x2bbbcc6e0\n",
|
||||
"ggml_metal_init: loaded kernel_cpy_f32_f16 0x2bbbccf90\n",
|
||||
"ggml_metal_init: loaded kernel_cpy_f32_f32 0x2bbbcd900\n",
|
||||
"ggml_metal_init: loaded kernel_cpy_f16_f16 0x2bbbce1f0\n",
|
||||
"ggml_metal_init: recommendedMaxWorkingSetSize = 21845.34 MB\n",
|
||||
"ggml_metal_init: hasUnifiedMemory = true\n",
|
||||
"ggml_metal_init: maxTransferRate = built-in GPU\n",
|
||||
"ggml_metal_add_buffer: allocated 'data ' buffer, size = 6984.06 MB, ( 6984.45 / 21845.34)\n",
|
||||
"ggml_metal_add_buffer: allocated 'eval ' buffer, size = 1024.00 MB, ( 8008.45 / 21845.34)\n",
|
||||
"ggml_metal_add_buffer: allocated 'kv ' buffer, size = 1602.00 MB, ( 9610.45 / 21845.34)\n",
|
||||
"ggml_metal_add_buffer: allocated 'scr0 ' buffer, size = 512.00 MB, (10122.45 / 21845.34)\n",
|
||||
"ggml_metal_add_buffer: allocated 'scr1 ' buffer, size = 512.00 MB, (10634.45 / 21845.34)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.llms import GPT4All\n",
|
||||
"from langchain.chains import RetrievalQA\n",
|
||||
"\n",
|
||||
"llm = GPT4All(model=\"/Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin\",max_tokens=2048)\n",
|
||||
"qa_chain = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fa82f437",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Customizing the prompt\n",
|
||||
"\n",
|
||||
"The prompt in `RetrievalQA` chain can be easily customized."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "e4fee704",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ggml_metal_free: deallocating\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The approaches to task decomposition include using LLM with simple prompting, task-specific instructions, or human inputs. Thanks for asking!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import RetrievalQA\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"template = \"\"\"Use the following pieces of context to answer the question at the end. \n",
|
||||
"If you don't know the answer, just say that you don't know, don't try to make up an answer. \n",
|
||||
"Use three sentences maximum and keep the answer as concise as possible. \n",
|
||||
"Always say \"thanks for asking!\" at the end of the answer. \n",
|
||||
"{context}\n",
|
||||
"Question: {question}\n",
|
||||
"Helpful Answer:\"\"\"\n",
|
||||
"QA_CHAIN_PROMPT = PromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)\n",
|
||||
"qa_chain = RetrievalQA.from_chain_type(\n",
|
||||
" llm,\n",
|
||||
" retriever=vectorstore.as_retriever(),\n",
|
||||
" chain_type_kwargs={\"prompt\": QA_CHAIN_PROMPT}\n",
|
||||
")\n",
|
||||
"result = qa_chain({\"query\": question})\n",
|
||||
"result[\"result\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ff40e8db",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Return source documents\n",
|
||||
"\n",
|
||||
"The full set of retrieved documents used for answer distillation can be returned using `return_source_documents=True`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "60004293",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"4\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Document(page_content='Task decomposition can be done (1) by LLM with simple prompting like \"Steps for XYZ.\\\\n1.\", \"What are the subgoals for achieving XYZ?\", (2) by using task-specific instructions; e.g. \"Write a story outline.\" for writing a novel, or (3) with human inputs.', metadata={'description': 'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyAGI, serve as inspiring examples. The potentiality of LLM extends beyond generating well-written copies, stories, essays and programs; it can be framed as a powerful general problem solver.\\nAgent System Overview In a LLM-powered autonomous agent system, LLM functions as the agent’s brain, complemented by several key components:', 'language': 'en', 'source': 'https://lilianweng.github.io/posts/2023-06-23-agent/', 'title': \"LLM Powered Autonomous Agents | Lil'Log\"})"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import RetrievalQA\n",
|
||||
"\n",
|
||||
"qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),\n",
|
||||
" return_source_documents=True)\n",
|
||||
"result = qa_chain({\"query\": question})\n",
|
||||
"print(len(result['source_documents']))\n",
|
||||
"result['source_documents'][0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1b600236",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Return citations\n",
|
||||
"\n",
|
||||
"Answer citations can be returned using `RetrievalQAWithSourcesChain`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "948f6d19",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'question': 'What are the approaches to Task Decomposition?',\n",
|
||||
" 'answer': 'The approaches to Task Decomposition include:\\n1. Using LLM with simple prompting, such as providing steps or subgoals for achieving a task.\\n2. Using task-specific instructions, such as providing a specific instruction like \"Write a story outline\" for writing a novel.\\n3. Using human inputs to decompose the task.\\nAnother approach is the Tree of Thoughts, which extends the Chain of Thought (CoT) technique by exploring multiple reasoning possibilities at each step and generating multiple thoughts per step, creating a tree structure. The search process can be BFS or DFS, and each state can be evaluated by a classifier or majority vote.\\nSources: https://lilianweng.github.io/posts/2023-06-23-agent/',\n",
|
||||
" 'sources': ''}"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import RetrievalQAWithSourcesChain\n",
|
||||
"\n",
|
||||
"qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm,retriever=vectorstore.as_retriever())\n",
|
||||
"\n",
|
||||
"result = qa_chain({\"question\": question})\n",
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "73d0b138",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Customizing retrieved document processing\n",
|
||||
"\n",
|
||||
"Retrieved documents can be fed to an LLM for answer distillation in a few different ways.\n",
|
||||
"\n",
|
||||
"`stuff`, `refine`, `map-reduce`, and `map-rerank` chains for passing documents to an LLM prompt are well summarized [here](/docs/modules/chains/document/).\n",
|
||||
" \n",
|
||||
"`stuff` is commonly used because it simply \"stuffs\" all retrieved documents into the prompt.\n",
|
||||
"\n",
|
||||
"The [load_qa_chain](/docs/use_cases/question_answering/how_to/question_answering.html) is an easy way to pass documents to an LLM using these various approaches (e.g., see `chain_type`)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "29aa139f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'output_text': 'The approaches to task decomposition mentioned in the provided context are:\\n\\n1. Chain of thought (CoT): This approach involves instructing the language model to \"think step by step\" and decompose complex tasks into smaller and simpler steps. It enhances model performance on complex tasks by utilizing more test-time computation.\\n\\n2. Tree of Thoughts: This approach extends CoT by exploring multiple reasoning possibilities at each step. It decomposes the problem into multiple thought steps and generates multiple thoughts per step, creating a tree structure. The search process can be BFS or DFS, and each state is evaluated by a classifier or majority vote.\\n\\n3. LLM with simple prompting: This approach involves using a language model with simple prompts like \"Steps for XYZ\" or \"What are the subgoals for achieving XYZ?\" to perform task decomposition.\\n\\n4. Task-specific instructions: This approach involves providing task-specific instructions to guide the language model in decomposing the task. For example, providing the instruction \"Write a story outline\" for the task of writing a novel.\\n\\n5. Human inputs: Task decomposition can also be done with human inputs, where humans provide guidance and input to break down the task into smaller subtasks.'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||
"\n",
|
||||
"chain = load_qa_chain(llm, chain_type=\"stuff\")\n",
|
||||
"chain({\"input_documents\": unique_docs, \"question\": question},return_only_outputs=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a8cb8cd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also pass the `chain_type` to `RetrievalQA`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "f68574bd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),\n",
|
||||
" chain_type=\"stuff\")\n",
|
||||
"result = qa_chain({\"query\": question})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b33aeb5f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In summary, the user can choose the desired level of abstraction for QA:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Step 6. Chat\n",
|
||||
"\n",
|
||||
"See our [use-case on chat](/docs/use_cases/chatbots) for detail on this!"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
from langchain_experimental.comprehend_moderation.amazon_comprehend_moderation import (
|
||||
AmazonComprehendModerationChain,
|
||||
)
|
||||
from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
|
||||
BaseModerationCallbackHandler,
|
||||
)
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
|
||||
BaseModerationActions,
|
||||
BaseModerationFilters,
|
||||
)
|
||||
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
|
||||
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
|
||||
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
|
||||
|
||||
__all__ = [
|
||||
"BaseModeration",
|
||||
"BaseModerationActions",
|
||||
"BaseModerationFilters",
|
||||
"ComprehendPII",
|
||||
"ComprehendIntent",
|
||||
"ComprehendToxicity",
|
||||
"BaseModerationCallbackHandler",
|
||||
"AmazonComprehendModerationChain",
|
||||
]
|
||||
@@ -0,0 +1,184 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
from langchain_experimental.comprehend_moderation.base_moderation import (
|
||||
BaseModeration,
|
||||
)
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
|
||||
BaseModerationCallbackHandler,
|
||||
)
|
||||
from langchain_experimental.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
class AmazonComprehendModerationChain(Chain):
|
||||
"""A subclass of Chain, designed to apply moderation to LLMs."""
|
||||
|
||||
output_key: str = "output" #: :meta private:
|
||||
"""Key used to fetch/store the output in data containers. Defaults to `output`"""
|
||||
|
||||
input_key: str = "input" #: :meta private:
|
||||
"""Key used to fetch/store the input in data containers. Defaults to `input`"""
|
||||
|
||||
moderation_config: Optional[Dict[str, Any]] = None
|
||||
"""Configuration settings for moderation"""
|
||||
|
||||
client: Optional[Any]
|
||||
"""boto3 client object for connection to Amazon Comprehend"""
|
||||
|
||||
region_name: Optional[str] = None
|
||||
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
|
||||
or region specified in ~/.aws/config in case it is not provided here.
|
||||
"""
|
||||
|
||||
credentials_profile_name: Optional[str] = None
|
||||
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
|
||||
has either access keys or role information specified.
|
||||
If not specified, the default credential profile or, if on an EC2 instance,
|
||||
credentials from IMDS will be used.
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
moderation_callback: Optional[BaseModerationCallbackHandler] = None
|
||||
"""Callback handler for moderation, this is different
|
||||
from regular callbacks which can be used in addition to this."""
|
||||
|
||||
unique_id: Optional[str] = None
|
||||
"""A unique id that can be used to identify or group a user or session"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Creates an Amazon Comprehend client
|
||||
|
||||
Args:
|
||||
values (Dict[str, Any]): A dictionary containing configuration values.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with the updated configuration values,
|
||||
including the Amazon Comprehend client.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If the 'boto3' package is not installed.
|
||||
ValueError: If there is an issue importing 'boto3' or loading
|
||||
AWS credentials.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
config = {
|
||||
"credentials_profile_name": "my-profile",
|
||||
"region_name": "us-west-2"
|
||||
}
|
||||
updated_config = create_client(config)
|
||||
comprehend_client = updated_config["client"]
|
||||
"""
|
||||
|
||||
if values.get("client") is not None:
|
||||
return values
|
||||
try:
|
||||
import boto3
|
||||
|
||||
if values.get("credentials_profile_name"):
|
||||
session = boto3.Session(profile_name=values["credentials_profile_name"])
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
client_params = {}
|
||||
if values.get("region_name"):
|
||||
client_params["region_name"] = values["region_name"]
|
||||
|
||||
values["client"] = session.client("comprehend", **client_params)
|
||||
|
||||
return values
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please install it with `pip install boto3`."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""
|
||||
Returns a list of output keys.
|
||||
|
||||
This method defines the output keys that will be used to access the output
|
||||
values produced by the chain or function. It ensures that the specified keys
|
||||
are available to access the outputs.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of output keys.
|
||||
|
||||
Note:
|
||||
This method is considered private and may not be intended for direct
|
||||
external use.
|
||||
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""
|
||||
Returns a list of input keys expected by the prompt.
|
||||
|
||||
This method defines the input keys that the prompt expects in order to perform
|
||||
its processing. It ensures that the specified keys are available for providing
|
||||
input to the prompt.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of input keys.
|
||||
|
||||
Note:
|
||||
This method is considered private and may not be intended for direct
|
||||
external use.
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Executes the moderation process on the input text and returns the processed
|
||||
output.
|
||||
|
||||
This internal method performs the moderation process on the input text. It
|
||||
converts the input prompt value to plain text, applies the specified filters,
|
||||
and then converts the filtered output back to a suitable prompt value object.
|
||||
Additionally, it provides the option to log information about the run using
|
||||
the provided `run_manager`.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary containing input values
|
||||
run_manager: A run manager to handle run-related events. Default is None
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary containing the processed output of the
|
||||
moderation process.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is an error during the moderation process
|
||||
"""
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_text("Running AmazonComprehendModerationChain...\n")
|
||||
|
||||
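# Delegate the actual checks (PII, toxicity, intent) to BaseModeration,
# which runs them in sequence over the input text.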
moderation = BaseModeration(
|
||||
client=self.client,
|
||||
config=self.moderation_config,
|
||||
moderation_callback=self.moderation_callback,
|
||||
unique_id=self.unique_id,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
response = moderation.moderate(prompt=inputs[self.input_keys[0]])
|
||||
|
||||
return {self.output_key: response}
|
||||
@@ -0,0 +1,176 @@
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.prompts.base import StringPromptValue
|
||||
from langchain.prompts.chat import ChatPromptValue
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
|
||||
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
|
||||
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
|
||||
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
|
||||
|
||||
|
||||
class BaseModeration:
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
moderation_callback: Optional[Any] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
):
|
||||
self.client = client
|
||||
self.config = config
|
||||
self.moderation_callback = moderation_callback
|
||||
self.unique_id = unique_id
|
||||
self.chat_message_index = 0
|
||||
self.run_manager = run_manager
|
||||
self.chain_id = str(uuid.uuid4())
|
||||
|
||||
def _convert_prompt_to_text(self, prompt: Any) -> str:
|
||||
input_text = str()
|
||||
|
||||
if isinstance(prompt, StringPromptValue):
|
||||
input_text = prompt.text
|
||||
elif isinstance(prompt, str):
|
||||
input_text = prompt
|
||||
elif isinstance(prompt, ChatPromptValue):
|
||||
"""
|
||||
We only check the last message in the message chain of a
ChatPromptTemplate. The typical chronology is
SystemMessage > HumanMessage > AIMessage, and so on. Assuming the chain
is invoked on every chat turn, all previous messages have already been
checked, so only the last message needs to be validated here, and only
HumanMessage and AIMessage are checked. We could instead loop through
the messages and use the additional_kwargs property of the HumanMessage
and AIMessage schema to mark messages that have already been moderated,
but that would mean this class could generate multiple text chunks, the
moderate() logic would need to be updated, and the prompt would have to
be reconstructed while keeping the messages in sequence.
|
||||
"""
|
||||
message = prompt.messages[-1]
|
||||
self.chat_message_index = len(prompt.messages) - 1
|
||||
if isinstance(message, HumanMessage):
|
||||
input_text = message.content
|
||||
|
||||
if isinstance(message, AIMessage):
|
||||
input_text = message.content
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
return input_text
|
||||
|
||||
def _convert_text_to_prompt(self, prompt: Any, text: str) -> Any:
|
||||
if isinstance(prompt, StringPromptValue):
|
||||
return StringPromptValue(text=text)
|
||||
elif isinstance(prompt, str):
|
||||
return text
|
||||
elif isinstance(prompt, ChatPromptValue):
|
||||
messages = prompt.messages
|
||||
message = messages[self.chat_message_index]
|
||||
|
||||
if isinstance(message, HumanMessage):
|
||||
messages[self.chat_message_index] = HumanMessage(
|
||||
content=text,
|
||||
example=message.example,
|
||||
additional_kwargs=message.additional_kwargs,
|
||||
)
|
||||
if isinstance(message, AIMessage):
|
||||
messages[self.chat_message_index] = AIMessage(
|
||||
content=text,
|
||||
example=message.example,
|
||||
additional_kwargs=message.additional_kwargs,
|
||||
)
|
||||
return ChatPromptValue(messages=messages)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
|
||||
def _moderation_class(self, moderation_class: Any) -> Callable:
|
||||
return moderation_class(
|
||||
client=self.client,
|
||||
callback=self.moderation_callback,
|
||||
unique_id=self.unique_id,
|
||||
chain_id=self.chain_id,
|
||||
).validate
|
||||
|
||||
def _log_message_for_verbose(self, message: str) -> None:
|
||||
if self.run_manager:
|
||||
self.run_manager.on_text(message)
|
||||
|
||||
def moderate(self, prompt: Any) -> str:
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
|
||||
ModerationIntentionError,
|
||||
ModerationPiiError,
|
||||
ModerationToxicityError,
|
||||
)
|
||||
|
||||
try:
|
||||
# convert prompt to text
|
||||
input_text = self._convert_prompt_to_text(prompt=prompt)
|
||||
output_text = str()
|
||||
# perform moderation
|
||||
if self.config is None:
|
||||
# In the absence of a config, every validator's action defaults to STOP
|
||||
self._log_message_for_verbose("Running pii validation...\n")
|
||||
pii_validate = self._moderation_class(moderation_class=ComprehendPII)
|
||||
output_text = pii_validate(prompt_value=input_text)
|
||||
|
||||
self._log_message_for_verbose("Running toxicity validation...\n")
|
||||
toxicity_validate = self._moderation_class(
|
||||
moderation_class=ComprehendToxicity
|
||||
)
|
||||
output_text = toxicity_validate(prompt_value=output_text)
|
||||
|
||||
self._log_message_for_verbose("Running intent validation...\n")
|
||||
intent_validate = self._moderation_class(
|
||||
moderation_class=ComprehendIntent
|
||||
)
|
||||
output_text = intent_validate(prompt_value=output_text)
|
||||
else:
|
||||
filter_functions = {
|
||||
"pii": ComprehendPII,
|
||||
"toxicity": ComprehendToxicity,
|
||||
"intent": ComprehendIntent,
|
||||
}
|
||||
filters = self.config["filters"]
|
||||
for _filter in filters:
|
||||
filter_name = f"{_filter}"
|
||||
if filter_name in filter_functions:
|
||||
self._log_message_for_verbose(
|
||||
f"Running {filter_name} Validation...\n"
|
||||
)
|
||||
validation_fn = self._moderation_class(
|
||||
moderation_class=filter_functions[filter_name]
|
||||
)
|
||||
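# Chain the validators: each filter receives the (possibly redacted)
# output of the previous one.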
input_text = input_text if not output_text else output_text
|
||||
output_text = validation_fn(
|
||||
prompt_value=input_text,
|
||||
config=self.config[filter_name]
|
||||
if filter_name in self.config
|
||||
else None,
|
||||
)
|
||||
# convert text to prompt and return
|
||||
return self._convert_text_to_prompt(prompt=prompt, text=output_text)
|
||||
|
||||
except ModerationPiiError as e:
|
||||
self._log_message_for_verbose(f"Found PII content..stopping..\n{str(e)}\n")
|
||||
raise e
|
||||
except ModerationToxicityError as e:
|
||||
self._log_message_for_verbose(
|
||||
f"Found Toxic content..stopping..\n{str(e)}\n"
|
||||
)
|
||||
raise e
|
||||
except ModerationIntentionError as e:
|
||||
self._log_message_for_verbose(
|
||||
f"Found Harmful intention..stopping..\n{str(e)}\n"
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -0,0 +1,64 @@
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
|
||||
class BaseModerationCallbackHandler:
|
||||
def __init__(self) -> None:
|
||||
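# Require subclasses to override at least one hook; comparing __qualname__
# detects methods left unchanged from this base class.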
if (
|
||||
self._is_method_unchanged(
|
||||
BaseModerationCallbackHandler.on_after_pii, self.on_after_pii
|
||||
)
|
||||
and self._is_method_unchanged(
|
||||
BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
|
||||
)
|
||||
and self._is_method_unchanged(
|
||||
BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
|
||||
)
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Subclasses must override at least one of on_after_pii(), "
|
||||
"on_after_toxicity(), or on_after_intent() functions."
|
||||
)
|
||||
|
||||
def _is_method_unchanged(
|
||||
self, base_method: Callable, derived_method: Callable
|
||||
) -> bool:
|
||||
return base_method.__qualname__ == derived_method.__qualname__
|
||||
|
||||
async def on_after_pii(
|
||||
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run after PII validation is complete."""
|
||||
raise NotImplementedError("Subclasses should implement this async method.")
|
||||
|
||||
async def on_after_toxicity(
|
||||
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run after Toxicity validation is complete."""
|
||||
raise NotImplementedError("Subclasses should implement this async method.")
|
||||
|
||||
async def on_after_intent(
|
||||
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run after Toxicity validation is complete."""
|
||||
raise NotImplementedError("Subclasses should implement this async method.")
|
||||
|
||||
@property
|
||||
def pii_callback(self) -> bool:
|
||||
return (
|
||||
self.on_after_pii.__func__ # type: ignore
|
||||
is not BaseModerationCallbackHandler.on_after_pii
|
||||
)
|
||||
|
||||
@property
|
||||
def toxicity_callback(self) -> bool:
|
||||
return (
|
||||
self.on_after_toxicity.__func__ # type: ignore
|
||||
is not BaseModerationCallbackHandler.on_after_toxicity
|
||||
)
|
||||
|
||||
@property
|
||||
def intent_callback(self) -> bool:
|
||||
return (
|
||||
self.on_after_intent.__func__ # type: ignore
|
||||
is not BaseModerationCallbackHandler.on_after_intent
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BaseModerationActions(Enum):
|
||||
STOP = 1
|
||||
ALLOW = 2
|
||||
|
||||
|
||||
class BaseModerationFilters(str, Enum):
|
||||
PII = "pii"
|
||||
TOXICITY = "toxicity"
|
||||
INTENT = "intent"
|
||||
@@ -0,0 +1,43 @@
|
||||
class ModerationPiiError(Exception):
|
||||
"""Exception raised if PII entities are detected.
|
||||
|
||||
Attributes:
|
||||
message -- explanation of the error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str = "The prompt contains PII entities and cannot be processed"
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ModerationToxicityError(Exception):
|
||||
"""Exception raised if Toxic entities are detected.
|
||||
|
||||
Attributes:
|
||||
message -- explanation of the error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str = "The prompt contains toxic content and cannot be processed"
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ModerationIntentionError(Exception):
|
||||
"""Exception raised if Intention entities are detected.
|
||||
|
||||
Attributes:
|
||||
message -- explanation of the error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = (
|
||||
"The prompt indicates an un-desired intent and " "cannot be processed"
|
||||
),
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -0,0 +1,101 @@
|
||||
import asyncio
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
|
||||
ModerationIntentionError,
|
||||
)
|
||||
|
||||
|
||||
class ComprehendIntent:
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
callback: Optional[Any] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
chain_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.moderation_beacon = {
|
||||
"moderation_chain_id": chain_id,
|
||||
"moderation_type": "Intent",
|
||||
"moderation_status": "LABELS_NOT_FOUND",
|
||||
}
|
||||
self.callback = callback
|
||||
self.unique_id = unique_id
|
||||
|
||||
def _get_arn(self) -> str:
|
||||
region_name = self.client.meta.region_name
|
||||
service = "comprehend"
|
||||
intent_endpoint = "document-classifier-endpoint/prompt-intent"
|
||||
return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
|
||||
|
||||
def validate(
|
||||
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Check and validate the intent of the given prompt text.
|
||||
|
||||
Args:
|
||||
comprehend_client: Comprehend client for intent classification
|
||||
prompt_value (str): The input text to be checked for unintended intent
|
||||
config (Dict[str, Any]): Configuration settings for intent checks
|
||||
|
||||
Raises:
|
||||
ValueError: If unintended intent is found in the prompt text based
|
||||
on the specified threshold.
|
||||
|
||||
Returns:
|
||||
str: The input prompt_value.
|
||||
|
||||
Note:
|
||||
This function checks the intent of the provided prompt text using
|
||||
Comprehend's classify_document API and raises an error if unintended
|
||||
intent is detected with a score above the specified threshold.
|
||||
|
||||
"""
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
|
||||
BaseModerationActions,
|
||||
)
|
||||
|
||||
threshold = config.get("threshold", 0.5) if config else 0.5
|
||||
action = (
|
||||
config.get("action", BaseModerationActions.STOP)
|
||||
if config
|
||||
else BaseModerationActions.STOP
|
||||
)
|
||||
intent_found = False
|
||||
|
||||
if action == BaseModerationActions.ALLOW:
|
||||
warnings.warn(
|
||||
"You have allowed content with Harmful content."
|
||||
"Defaulting to STOP action..."
|
||||
)
|
||||
action = BaseModerationActions.STOP
|
||||
|
||||
endpoint_arn = self._get_arn()
|
||||
response = self.client.classify_document(
|
||||
Text=prompt_value, EndpointArn=endpoint_arn
|
||||
)
|
||||
|
||||
if self.callback and self.callback.intent_callback:
|
||||
self.moderation_beacon["moderation_input"] = prompt_value
|
||||
self.moderation_beacon["moderation_output"] = response
|
||||
|
||||
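# Flag the prompt if the classifier labels it UNDESIRED_PROMPT with a
# score at or above the configured threshold.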
for class_result in response["Classes"]:
|
||||
if (
|
||||
class_result["Score"] >= threshold
|
||||
and class_result["Name"] == "UNDESIRED_PROMPT"
|
||||
):
|
||||
intent_found = True
|
||||
break
|
||||
|
||||
if self.callback and self.callback.intent_callback:
|
||||
if intent_found:
|
||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
|
||||
)
|
||||
if intent_found:
|
||||
raise ModerationIntentionError
|
||||
return prompt_value
|
||||
@@ -0,0 +1,173 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
|
||||
ModerationPiiError,
|
||||
)
|
||||
|
||||
|
||||
class ComprehendPII:
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
callback: Optional[Any] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
chain_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.moderation_beacon = {
|
||||
"moderation_chain_id": chain_id,
|
||||
"moderation_type": "PII",
|
||||
"moderation_status": "LABELS_NOT_FOUND",
|
||||
}
|
||||
self.callback = callback
|
||||
self.unique_id = unique_id
|
||||
|
||||
def validate(
|
||||
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
|
||||
BaseModerationActions,
|
||||
)
|
||||
|
||||
if config:
|
||||
action = config.get("action", BaseModerationActions.STOP)
|
||||
if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
|
||||
raise ValueError("Action can either be stop or allow")
|
||||
|
||||
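# STOP checks with contains_pii_entities and raises on a match;
# ALLOW detects individual entities and redacts them in place.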
return (
|
||||
self._contains_pii(prompt_value=prompt_value, config=config)
|
||||
if action == BaseModerationActions.STOP
|
||||
else self._detect_pii(prompt_value=prompt_value, config=config)
|
||||
)
|
||||
else:
|
||||
return self._contains_pii(prompt_value=prompt_value)
|
||||
|
||||
def _contains_pii(
|
||||
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Checks for Personally Identifiable Information (PII) labels above a
|
||||
specified threshold.
|
||||
|
||||
Args:
|
||||
prompt_value (str): The input text to be checked for PII labels.
|
||||
config (Dict[str, Any]): Configuration for PII check and actions.
|
||||
|
||||
Returns:
|
||||
str: the original prompt
|
||||
|
||||
Note:
|
||||
- The provided client should be initialized with valid AWS credentials.
|
||||
"""
|
||||
pii_identified = self.client.contains_pii_entities(
|
||||
Text=prompt_value, LanguageCode="en"
|
||||
)
|
||||
|
||||
if self.callback and self.callback.pii_callback:
|
||||
self.moderation_beacon["moderation_input"] = prompt_value
|
||||
self.moderation_beacon["moderation_output"] = pii_identified
|
||||
|
||||
threshold = config.get("threshold", 0.5) if config else 0.5
|
||||
pii_labels = config.get("labels", []) if config else []
|
||||
pii_found = False
|
||||
for entity in pii_identified["Labels"]:
|
||||
if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
|
||||
entity["Score"] >= threshold and not pii_labels
|
||||
):
|
||||
pii_found = True
|
||||
break
|
||||
|
||||
if self.callback and self.callback.pii_callback:
|
||||
if pii_found:
|
||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
|
||||
)
|
||||
if pii_found:
|
||||
raise ModerationPiiError
|
||||
return prompt_value
|
||||
|
||||
def _detect_pii(self, prompt_value: str, config: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Detects and handles Personally Identifiable Information (PII) entities in the
|
||||
given prompt text using Amazon Comprehend's detect_pii_entities API. The
|
||||
function provides options to redact or stop processing based on the identified
|
||||
PII entities and a provided configuration.
|
||||
|
||||
Args:
|
||||
prompt_value (str): The input text to be checked for PII entities.
|
||||
config (Dict[str, Any]): A configuration specifying how to handle
|
||||
PII entities.
|
||||
|
||||
Returns:
|
||||
str: The processed prompt text with redacted PII entities or raised
|
||||
exceptions.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prompt contains configured PII entities for
|
||||
stopping processing.
|
||||
|
||||
Note:
|
||||
- If PII is not found in the prompt, the original prompt is returned.
|
||||
- The client should be initialized with valid AWS credentials.
|
||||
"""
|
||||
pii_identified = self.client.detect_pii_entities(
|
||||
Text=prompt_value, LanguageCode="en"
|
||||
)
|
||||
|
||||
if self.callback and self.callback.pii_callback:
|
||||
self.moderation_beacon["moderation_input"] = prompt_value
|
||||
self.moderation_beacon["moderation_output"] = pii_identified
|
||||
|
||||
if (pii_identified["Entities"]) == []:
|
||||
if self.callback and self.callback.pii_callback:
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
|
||||
)
|
||||
return prompt_value
|
||||
|
||||
pii_found = False
|
||||
if not config and pii_identified["Entities"]:
|
||||
for entity in pii_identified["Entities"]:
|
||||
if entity["Score"] >= 0.5:
|
||||
pii_found = True
|
||||
break
|
||||
|
||||
if self.callback and self.callback.pii_callback:
|
||||
if pii_found:
|
||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
|
||||
)
|
||||
if pii_found:
|
||||
raise ModerationPiiError
|
||||
else:
|
||||
threshold = config.get("threshold", 0.5) # type: ignore
|
||||
pii_labels = config.get("labels", []) # type: ignore
|
||||
mask_marker = config.get("mask_character", "*") # type: ignore
|
||||
pii_found = False
|
||||
|
||||
for entity in pii_identified["Entities"]:
|
||||
if (
|
||||
pii_labels
|
||||
and entity["Type"] in pii_labels
|
||||
and entity["Score"] >= threshold
|
||||
) or (not pii_labels and entity["Score"] >= threshold):
|
||||
pii_found = True
|
||||
char_offset_begin = entity["BeginOffset"]
|
||||
char_offset_end = entity["EndOffset"]
|
||||
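# Mask the span [BeginOffset, EndOffset) with the configured mask
# character, preserving the original text length.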
prompt_value = (
|
||||
prompt_value[:char_offset_begin]
|
||||
+ mask_marker * (char_offset_end - char_offset_begin)
|
||||
+ prompt_value[char_offset_end:]
|
||||
)
|
||||
|
||||
if self.callback and self.callback.pii_callback:
|
||||
if pii_found:
|
||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
|
||||
)
|
||||
|
||||
return prompt_value
|
||||
@@ -0,0 +1,209 @@
import asyncio
import importlib
import warnings
from typing import Any, Dict, List, Optional

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationToxicityError,
)


class ComprehendToxicity:
    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        self.client = client
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "Toxicity",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

def _toxicity_init_validate(self, max_size: int) -> Any:
|
||||
"""
|
||||
Validate and initialize toxicity processing configuration.
|
||||
|
||||
Args:
|
||||
max_size (int): Maximum sentence size defined in the configuration object.
|
||||
|
||||
Raises:
|
||||
Exception: If the maximum sentence size exceeds the 5KB limit.
|
||||
|
||||
Note:
|
||||
This function ensures that the NLTK punkt tokenizer is downloaded if not
|
||||
already present.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if max_size > 1024 * 5:
|
||||
raise Exception("The sentence length should not exceed 5KB.")
|
||||
try:
|
||||
nltk = importlib.import_module("nltk")
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
return nltk
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import nltk python package. "
|
||||
"Please install it with `pip install nltk`."
|
||||
)
|
||||
except LookupError:
|
||||
nltk.download("punkt")
|
||||
|
||||
def _split_paragraph(
|
||||
self, prompt_value: str, max_size: int = 1024 * 4
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Split a paragraph into chunks of sentences, respecting the maximum size limit.
|
||||
|
||||
Args:
|
||||
paragraph (str): The input paragraph to be split into chunks
|
||||
max_size (int, optional): The maximum size limit in bytes for each chunk
|
||||
Defaults to 1024.
|
||||
|
||||
Returns:
|
||||
List[List[str]]: A list of chunks, where each chunk is a list of sentences
|
||||
|
||||
Note:
|
||||
This function validates the maximum sentence size based on service limits
|
||||
using the 'toxicity_init_validate' function. It uses the NLTK sentence
|
||||
tokenizer to split the paragraph into sentences.
|
||||
|
||||
"""
|
||||
|
||||
# validate max. sentence size based on Service limits
|
||||
nltk = self._toxicity_init_validate(max_size)
|
||||
|
||||
sentences = nltk.sent_tokenize(prompt_value)
|
||||
|
||||
chunks = []
|
||||
current_chunk = [] # type: ignore
|
||||
current_size = 0
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_size = len(sentence.encode("utf-8"))
|
||||
|
||||
# If adding a new sentence exceeds max_size or
|
||||
# current_chunk has 10 sentences, start a new chunk
|
||||
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
|
||||
if current_chunk: # Avoid appending empty chunks
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = []
|
||||
current_size = 0
|
||||
|
||||
current_chunk.append(sentence)
|
||||
current_size += sentence_size
|
||||
|
||||
# Add any remaining sentences
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
def validate(
|
||||
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Check the toxicity of a given text prompt using AWS Comprehend service
|
||||
and apply actions based on configuration.
|
||||
|
||||
Args:
|
||||
prompt_value (str): The text content to be checked for toxicity.
|
||||
config (Dict[str, Any]): Configuration for toxicity checks and actions.
|
||||
|
||||
Returns:
|
||||
str: The original prompt_value if allowed or no toxicity found.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prompt contains toxic labels and cannot be
|
||||
processed based on the configuration.
|
||||
"""
|
||||
|
||||
chunks = self._split_paragraph(prompt_value=prompt_value)
|
||||
for sentence_list in chunks:
|
||||
segments = [{"Text": sentence} for sentence in sentence_list]
|
||||
response = self.client.detect_toxic_content(
|
||||
TextSegments=segments, LanguageCode="en"
|
||||
)
|
||||
if self.callback and self.callback.toxicity_callback:
|
||||
self.moderation_beacon["moderation_input"] = segments # type: ignore
|
||||
self.moderation_beacon["moderation_output"] = response
|
||||
|
||||
if config:
|
||||
from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501
|
||||
BaseModerationActions,
|
||||
)
|
||||
|
||||
toxicity_found = False
|
||||
action = config.get("action", BaseModerationActions.STOP)
|
||||
if action not in [
|
||||
BaseModerationActions.STOP,
|
||||
BaseModerationActions.ALLOW,
|
||||
]:
|
||||
raise ValueError("Action can either be stop or allow")
|
||||
|
||||
threshold = config.get("threshold", 0.5) if config else 0.5
|
||||
toxicity_labels = config.get("labels", []) if config else []
|
||||
|
||||
if action == BaseModerationActions.STOP:
|
||||
for item in response["ResultList"]:
|
||||
for label in item["Labels"]:
|
||||
if (
|
||||
label
|
||||
and (
|
||||
not toxicity_labels
|
||||
or label["Name"] in toxicity_labels
|
||||
)
|
||||
and label["Score"] >= threshold
|
||||
):
|
||||
toxicity_found = True
|
||||
break
|
||||
|
||||
if action == BaseModerationActions.ALLOW:
|
||||
if not toxicity_labels:
|
||||
warnings.warn(
|
||||
"You have allowed toxic content without specifying "
|
||||
"any toxicity labels."
|
||||
)
|
||||
else:
|
||||
for item in response["ResultList"]:
|
||||
for label in item["Labels"]:
|
||||
if (
|
||||
label["Name"] in toxicity_labels
|
||||
and label["Score"] >= threshold
|
||||
):
|
||||
toxicity_found = True
|
||||
break
|
||||
|
||||
if self.callback and self.callback.toxicity_callback:
|
||||
if toxicity_found:
|
||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_toxicity(
|
||||
self.moderation_beacon, self.unique_id
|
||||
)
|
||||
)
|
||||
if toxicity_found:
|
||||
raise ModerationToxicityError
|
||||
else:
|
||||
if response["ResultList"]:
|
||||
detected_toxic_labels = list()
|
||||
for item in response["ResultList"]:
|
||||
detected_toxic_labels.extend(item["Labels"])
|
||||
if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
|
||||
if self.callback and self.callback.toxicity_callback:
|
||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||
asyncio.create_task(
|
||||
self.callback.on_after_toxicity(
|
||||
self.moderation_beacon, self.unique_id
|
||||
)
|
||||
)
|
||||
raise ModerationToxicityError
|
||||
|
||||
return prompt_value
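Taken together, `validate` splits the prompt into sentence chunks with NLTK, sends each chunk to the `detect_toxic_content` call shown above, and raises `ModerationToxicityError` when a label crosses the threshold. A rough usage sketch: the boto3 client, region, and module path are assumptions, and your botocore version must expose `detect_toxic_content`:

```python
# Rough usage sketch; region, credentials, and the module path are assumptions.
import boto3

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationToxicityError,
)
from langchain_experimental.comprehend_moderation.toxicity import (  # path assumed
    ComprehendToxicity,
)

comprehend_client = boto3.client("comprehend", region_name="us-east-1")
toxicity_check = ComprehendToxicity(client=comprehend_client)

try:
    # Returns the prompt unchanged when no label crosses the 0.5 default threshold.
    prompt = toxicity_check.validate("A perfectly friendly question about pandas.")
except ModerationToxicityError:
    print("Toxic content detected; stopping the chain.")
```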
|
||||
@@ -33,6 +33,7 @@ from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
@@ -302,6 +303,14 @@ class RedisSemanticCache(BaseCache):

    # TODO - implement a TTL policy in Redis

    DEFAULT_SCHEMA = {
        "content_key": "prompt",
        "text": [
            {"name": "prompt"},
        ],
        "extra": [{"name": "return_val"}, {"name": "llm_string"}],
    }

    def __init__(
        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
@@ -349,12 +358,14 @@ class RedisSemanticCache(BaseCache):
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
        except ValueError:
            redis = RedisVectorstore(
                embedding_function=self.embedding.embed_query,
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
            _embedding = self.embedding.embed_query(text="test")
            redis._create_index(dim=len(_embedding))
@@ -374,17 +385,18 @@ class RedisSemanticCache(BaseCache):
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations = []
        generations: List = []
        # Read from a Hash
        results = llm_cache.similarity_search_limit_score(
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            score_threshold=self.score_threshold,
            distance_threshold=self.score_threshold,
        )
        if results:
            for document in results:
                for text in document.metadata["return_val"]:
                    generations.append(Generation(text=text))
                generations.extend(
                    _load_generations_from_json(document.metadata["return_val"])
                )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -402,11 +414,11 @@ class RedisSemanticCache(BaseCache):
            )
            return
        llm_cache = self._get_llm_cache(llm_string)
        # Write to vectorstore
        _dump_generations_to_json([g for g in return_val])
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": [generation.text for generation in return_val],
            "return_val": _dump_generations_to_json([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
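With this change the cache stores whole `Generation` objects as JSON (via `_dump_generations_to_json`) instead of bare strings, and looks them up with `similarity_search(..., distance_threshold=...)` on the new Redis vectorstore. A minimal wiring sketch; the Redis URL and the OpenAI models are placeholders:

```python
# Minimal sketch of wiring up the semantic cache (URL and models are placeholders).
import langchain
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI

langchain.llm_cache = RedisSemanticCache(
    redis_url="redis://localhost:6379",
    embedding=OpenAIEmbeddings(),
    score_threshold=0.2,  # how close a cached prompt must be to count as a hit
)

llm = OpenAI()
llm.predict("Tell me a joke")    # first call: hits the model, writes to the cache
llm.predict("Tell me one joke")  # semantically similar prompt: served from the cache
```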
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
    from langchain.schema.agent import AgentAction, AgentFinish
    from langchain.schema.document import Document
    from langchain.schema.messages import BaseMessage
    from langchain.schema.output import LLMResult
    from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult


class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""
        """Run on new LLM token. Only available when streaming is enabled.

        Args:
            token (str): The new token.
            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
                containing content and other information.
        """

    def on_llm_end(
        self,
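With the new keyword-only `chunk` argument, streaming callbacks can inspect the full generation chunk (its message content and `generation_info`) rather than only the token string. A hedged sketch of a custom handler that takes advantage of it; treat it as illustrative rather than a handler shipped with the library:

```python
# Sketch of a handler that uses the new `chunk` argument.
from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkCollectingHandler(BaseCallbackHandler):
    def __init__(self) -> None:
        self.chunks: list = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # `chunk` is None for models that have not been updated to pass it.
        if chunk is not None:
            self.chunks.append(chunk)
        print(token, end="", flush=True)
```

Handlers written against the old signature keep working, since `chunk` is keyword-only and defaults to `None`.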
|
||||
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
|
||||
@@ -49,6 +49,7 @@ from langchain.schema import (
|
||||
LLMResult,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langsmith import Client as LangSmithClient
|
||||
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token.
|
||||
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
chunk=chunk,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token.
|
||||
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"on_llm_new_token",
|
||||
"ignore_llm",
|
||||
token,
|
||||
chunk=chunk,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
|
||||
@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.output import ChatGeneration, LLMResult
|
||||
from langchain.schema.output import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
event_kwargs: Dict[str, Any] = {"token": token}
|
||||
if chunk:
|
||||
event_kwargs["chunk"] = chunk
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "new_token",
|
||||
"time": datetime.utcnow(),
|
||||
"kwargs": {"token": token},
|
||||
"kwargs": event_kwargs,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -120,9 +120,11 @@ class BaseQAWithSourcesChain(Chain, ABC):

    def _split_sources(self, answer: str) -> Tuple[str, str]:
        """Split sources from answer."""
        if re.search(r"SOURCES:\s", answer):
            answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2]
            sources = re.split(r"\n", sources)[0]
        if re.search(r"SOURCES?[:\s]", answer, re.IGNORECASE):
            answer, sources = re.split(
                r"SOURCES?[:\s]|QUESTION:\s", answer, flags=re.IGNORECASE
            )[:2]
            sources = re.split(r"\n", sources)[0].strip()
        else:
            sources = ""
        return answer, sources
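The updated pattern accepts `Sources:` in any case and tolerates a missing colon, where the old pattern only matched an uppercase `SOURCES:` followed by whitespace. A quick illustration of the difference:

```python
import re

answer = "Paris is the capital of France.\nSources: 3-pl"

# Old behaviour: the uppercase-only pattern does not match, so no sources are split out.
print(bool(re.search(r"SOURCES:\s", answer)))  # False

# New behaviour: case-insensitive match, then split and strip the sources tail.
if re.search(r"SOURCES?[:\s]", answer, re.IGNORECASE):
    answer_part, sources = re.split(
        r"SOURCES?[:\s]|QUESTION:\s", answer, flags=re.IGNORECASE
    )[:2]
    sources = re.split(r"\n", sources)[0].strip()
    print(repr(answer_part), repr(sources))  # 'Paris is the capital of France.\n' '3-pl'
```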
|
||||
|
||||
@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
default_chunk_class = chunk.__class__
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.content)
|
||||
run_manager.on_llm_new_token(chunk.content, chunk=chunk)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
default_chunk_class = chunk.__class__
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.content)
|
||||
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
|
||||
@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
for token in self.client.completions.create(
|
||||
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
|
||||
):
|
||||
yield GenerationChunk(text=token.completion)
|
||||
chunk = GenerationChunk(text=token.completion)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token.completion)
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
stream=True,
|
||||
**params,
|
||||
):
|
||||
yield GenerationChunk(text=token.completion)
|
||||
chunk = GenerationChunk(text=token.completion)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token.completion)
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
|
||||
@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
logprobs=chunk.generation_info["logprobs"]
|
||||
if chunk.generation_info
|
||||
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
logprobs=chunk.generation_info["logprobs"]
|
||||
if chunk.generation_info
|
||||
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
|
||||
self, messages=messages, run_manager=run_manager, **params
|
||||
):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
yield GenerationChunk(text=token)
|
||||
chunk = GenerationChunk(text=token)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token)
|
||||
run_manager.on_llm_new_token(token, chunk=chunk)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
|
||||
self, messages=messages, run_manager=run_manager, **params
|
||||
):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
yield GenerationChunk(text=token)
|
||||
chunk = GenerationChunk(text=token)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token)
|
||||
await run_manager.on_llm_new_token(token, chunk=chunk)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.schema.messages import get_buffer_string # noqa: 401
|
||||
|
||||
|
||||
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,16 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.client import Redis as RedisType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
|
||||
return np.array(array).astype(dtype).tobytes()
|
||||
|
||||
|
||||
class TokenEscaper:
|
||||
"""
|
||||
Escape punctuation within an input string.
|
||||
"""
|
||||
|
||||
# Characters that RediSearch requires us to escape during queries.
|
||||
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
|
||||
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"
|
||||
|
||||
def __init__(self, escape_chars_re: Optional[Pattern] = None):
|
||||
if escape_chars_re:
|
||||
self.escaped_chars_re = escape_chars_re
|
||||
else:
|
||||
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
|
||||
|
||||
def escape(self, value: str) -> str:
|
||||
def escape_symbol(match: re.Match) -> str:
|
||||
value = match.group(0)
|
||||
return f"\\{value}"
|
||||
|
||||
return self.escaped_chars_re.sub(escape_symbol, value)
|
||||
|
||||
|
||||
def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
|
||||
"""Check if the correct Redis modules are installed."""
|
||||
installed_modules = client.module_list()
|
||||
installed_modules = {
|
||||
module[b"name"].decode("utf-8"): module for module in installed_modules
|
||||
}
|
||||
for module in required_modules:
|
||||
if module["name"] in installed_modules and int(
|
||||
installed_modules[module["name"]][b"ver"]
|
||||
) >= int(module["ver"]):
|
||||
return
|
||||
# otherwise raise error
|
||||
error_message = (
|
||||
"Redis cannot be used as a vector database without RediSearch >=2.4"
|
||||
"Please head to https://redis.io/docs/stack/search/quick_start/"
|
||||
"to know more about installing the RediSearch module within Redis Stack."
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def get_client(redis_url: str, **kwargs: Any) -> RedisType:
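The new `TokenEscaper` is what lets user-supplied values be embedded safely in RediSearch queries; the filter classes added later in this change import it from this module. A small sketch of what it produces, following the escaping rules linked above:

```python
# Small sketch of TokenEscaper output; the import path is the one used by the
# new filters module in this change.
from langchain.utilities.redis import TokenEscaper

escaper = TokenEscaper()
# Punctuation that RediSearch treats as token separators gets backslash-escaped.
print(escaper.escape("hello-world!"))      # hello\-world\!
print(escaper.escape("user@example.com"))  # user\@example\.com
```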
|
||||
|
||||
@@ -1,664 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.utilities.redis import get_client
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.client import Redis as RedisType
|
||||
from redis.commands.search.query import Query
|
||||
|
||||
|
||||
# required modules
|
||||
REDIS_REQUIRED_MODULES = [
|
||||
{"name": "search", "ver": 20400},
|
||||
{"name": "searchlight", "ver": 20400},
|
||||
]
|
||||
|
||||
# distance mmetrics
|
||||
REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"]
|
||||
|
||||
|
||||
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
|
||||
"""Check if the correct Redis modules are installed."""
|
||||
installed_modules = client.module_list()
|
||||
installed_modules = {
|
||||
module[b"name"].decode("utf-8"): module for module in installed_modules
|
||||
}
|
||||
for module in required_modules:
|
||||
if module["name"] in installed_modules and int(
|
||||
installed_modules[module["name"]][b"ver"]
|
||||
) >= int(module["ver"]):
|
||||
return
|
||||
# otherwise raise error
|
||||
error_message = (
|
||||
"Redis cannot be used as a vector database without RediSearch >=2.4"
|
||||
"Please head to https://redis.io/docs/stack/search/quick_start/"
|
||||
"to know more about installing the RediSearch module within Redis Stack."
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def _check_index_exists(client: RedisType, index_name: str) -> bool:
|
||||
"""Check if Redis index exists."""
|
||||
try:
|
||||
client.ft(index_name).info()
|
||||
except: # noqa: E722
|
||||
logger.info("Index does not exist")
|
||||
return False
|
||||
logger.info("Index already exists")
|
||||
return True
|
||||
|
||||
|
||||
def _redis_key(prefix: str) -> str:
|
||||
"""Redis key schema for a given prefix."""
|
||||
return f"{prefix}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _redis_prefix(index_name: str) -> str:
|
||||
"""Redis key prefix for a given index."""
|
||||
return f"doc:{index_name}"
|
||||
|
||||
|
||||
def _default_relevance_score(val: float) -> float:
|
||||
return 1 - val
|
||||
|
||||
|
||||
class Redis(VectorStore):
|
||||
"""`Redis` vector store.
|
||||
|
||||
To use, you should have the ``redis`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Redis
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = Redis(
|
||||
redis_url="redis://username:password@localhost:6379"
|
||||
index_name="my-index",
|
||||
embedding_function=embeddings.embed_query,
|
||||
)
|
||||
|
||||
To use a redis replication setup with multiple redis server and redis sentinels
|
||||
set "redis_url" to "redis+sentinel://" scheme. With this url format a path is
|
||||
needed holding the name of the redis service within the sentinels to get the
|
||||
correct redis server connection. The default service name is "mymaster".
|
||||
|
||||
An optional username or password is used for booth connections to the rediserver
|
||||
and the sentinel, different passwords for server and sentinel are not supported.
|
||||
And as another constraint only one sentinel instance can be given:
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
vectorstore = Redis(
|
||||
redis_url="redis+sentinel://username:password@sentinelhost:26379/mymaster/0"
|
||||
index_name="my-index",
|
||||
embedding_function=embeddings.embed_query,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str,
|
||||
index_name: str,
|
||||
embedding_function: Callable,
|
||||
content_key: str = "content",
|
||||
metadata_key: str = "metadata",
|
||||
vector_key: str = "content_vector",
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
self.embedding_function = embedding_function
|
||||
self.index_name = index_name
|
||||
try:
|
||||
redis_client = get_client(redis_url=redis_url, **kwargs)
|
||||
# check if redis has redisearch module installed
|
||||
_check_redis_module_exist(redis_client, REDIS_REQUIRED_MODULES)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Redis failed to connect: {e}")
|
||||
|
||||
self.client = redis_client
|
||||
self.content_key = content_key
|
||||
self.metadata_key = metadata_key
|
||||
self.vector_key = vector_key
|
||||
self.distance_metric = distance_metric
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Accept embedding object directly
|
||||
return None
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
if self.relevance_score_fn:
|
||||
return self.relevance_score_fn
|
||||
|
||||
if self.distance_metric == "COSINE":
|
||||
return self._cosine_relevance_score_fn
|
||||
elif self.distance_metric == "IP":
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_metric == "L2":
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
return _default_relevance_score
|
||||
|
||||
def _create_index(self, dim: int = 1536) -> None:
|
||||
try:
|
||||
from redis.commands.search.field import TextField, VectorField
|
||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
# Check if index exists
|
||||
if not _check_index_exists(self.client, self.index_name):
|
||||
# Define schema
|
||||
schema = (
|
||||
TextField(name=self.content_key),
|
||||
TextField(name=self.metadata_key),
|
||||
VectorField(
|
||||
self.vector_key,
|
||||
"FLAT",
|
||||
{
|
||||
"TYPE": "FLOAT32",
|
||||
"DIM": dim,
|
||||
"DISTANCE_METRIC": self.distance_metric,
|
||||
},
|
||||
),
|
||||
)
|
||||
prefix = _redis_prefix(self.index_name)
|
||||
|
||||
# Create Redis Index
|
||||
self.client.ft(self.index_name).create_index(
|
||||
fields=schema,
|
||||
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
embeddings: Optional[List[List[float]]] = None,
|
||||
batch_size: int = 1000,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Add more texts to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
|
||||
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
|
||||
Defaults to None.
|
||||
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
|
||||
embeddings. Defaults to None.
|
||||
keys (List[str]) or ids (List[str]): Identifiers of entries.
|
||||
Defaults to None.
|
||||
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
|
||||
|
||||
Returns:
|
||||
List[str]: List of ids added to the vectorstore
|
||||
"""
|
||||
ids = []
|
||||
prefix = _redis_prefix(self.index_name)
|
||||
|
||||
# Get keys or ids from kwargs
|
||||
# Other vectorstores use ids
|
||||
keys_or_ids = kwargs.get("keys", kwargs.get("ids"))
|
||||
|
||||
# Write data to redis
|
||||
pipeline = self.client.pipeline(transaction=False)
|
||||
for i, text in enumerate(texts):
|
||||
# Use provided values by default or fallback
|
||||
key = keys_or_ids[i] if keys_or_ids else _redis_key(prefix)
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
embedding = embeddings[i] if embeddings else self.embedding_function(text)
|
||||
pipeline.hset(
|
||||
key,
|
||||
mapping={
|
||||
self.content_key: text,
|
||||
self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
|
||||
self.metadata_key: json.dumps(metadata),
|
||||
},
|
||||
)
|
||||
ids.append(key)
|
||||
|
||||
# Write batch
|
||||
if i % batch_size == 0:
|
||||
pipeline.execute()
|
||||
|
||||
# Cleanup final batch
|
||||
pipeline.execute()
|
||||
return ids
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Returns the most similar indexed documents to the query text.
|
||||
|
||||
Args:
|
||||
query (str): The query text for which to find similar documents.
|
||||
k (int): The number of documents to return. Default is 4.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents that are most similar to the query text.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def similarity_search_limit_score(
|
||||
self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Returns the most similar indexed documents to the query text within the
|
||||
score_threshold range.
|
||||
|
||||
Args:
|
||||
query (str): The query text for which to find similar documents.
|
||||
k (int): The number of documents to return. Default is 4.
|
||||
score_threshold (float): The minimum matching score required for a document
|
||||
to be considered a match. Defaults to 0.2.
|
||||
Because the similarity calculation algorithm is based on cosine
|
||||
similarity, the smaller the angle, the higher the similarity.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents that are most similar to the query text,
|
||||
including the match score for each document.
|
||||
|
||||
Note:
|
||||
If there are no documents that satisfy the score_threshold value,
|
||||
an empty list is returned.
|
||||
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
||||
return [doc for doc, score in docs_and_scores if score < score_threshold]
|
||||
|
||||
def _prepare_query(self, k: int) -> Query:
|
||||
try:
|
||||
from redis.commands.search.query import Query
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
# Prepare the Query
|
||||
hybrid_fields = "*"
|
||||
base_query = (
|
||||
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
|
||||
)
|
||||
return_fields = [self.metadata_key, self.content_key, "vector_score", "id"]
|
||||
return (
|
||||
Query(base_query)
|
||||
.return_fields(*return_fields)
|
||||
.sort_by("vector_score")
|
||||
.paging(0, k)
|
||||
.dialect(2)
|
||||
)
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query and score for each
|
||||
"""
|
||||
# Creates embedding vector from user query
|
||||
embedding = self.embedding_function(query)
|
||||
|
||||
# Creates Redis query
|
||||
redis_query = self._prepare_query(k)
|
||||
|
||||
params_dict: Mapping[str, str] = {
|
||||
"vector": np.array(embedding) # type: ignore
|
||||
.astype(dtype=np.float32)
|
||||
.tobytes()
|
||||
}
|
||||
|
||||
# Perform vector search
|
||||
results = self.client.ft(self.index_name).search(redis_query, params_dict)
|
||||
|
||||
# Prepare document results
|
||||
docs_and_scores: List[Tuple[Document, float]] = []
|
||||
for result in results.docs:
|
||||
metadata = {**json.loads(result.metadata), "id": result.id}
|
||||
doc = Document(page_content=result.content, metadata=metadata)
|
||||
docs_and_scores.append((doc, float(result.vector_score)))
|
||||
return docs_and_scores
|
||||
|
||||
@classmethod
|
||||
def from_texts_return_keys(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
index_name: Optional[str] = None,
|
||||
content_key: str = "content",
|
||||
metadata_key: str = "metadata",
|
||||
vector_key: str = "content_vector",
|
||||
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
|
||||
**kwargs: Any,
|
||||
) -> Tuple[Redis, List[str]]:
|
||||
"""Create a Redis vectorstore from raw documents.
|
||||
This is a user-friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates a new index for the embeddings in Redis.
|
||||
3. Adds the documents to the newly created Redis index.
|
||||
4. Returns the keys of the newly created documents.
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Redis
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
redisearch, keys = RediSearch.from_texts_return_keys(
|
||||
texts,
|
||||
embeddings,
|
||||
redis_url="redis://username:password@localhost:6379"
|
||||
)
|
||||
"""
|
||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||
|
||||
if "redis_url" in kwargs:
|
||||
kwargs.pop("redis_url")
|
||||
|
||||
# Name of the search index if not given
|
||||
if not index_name:
|
||||
index_name = uuid.uuid4().hex
|
||||
|
||||
# Create instance
|
||||
instance = cls(
|
||||
redis_url,
|
||||
index_name,
|
||||
embedding.embed_query,
|
||||
content_key=content_key,
|
||||
metadata_key=metadata_key,
|
||||
vector_key=vector_key,
|
||||
distance_metric=distance_metric,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Create embeddings over documents
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
|
||||
# Create the search index
|
||||
instance._create_index(dim=len(embeddings[0]))
|
||||
|
||||
# Add data to Redis
|
||||
keys = instance.add_texts(texts, metadatas, embeddings)
|
||||
return instance, keys
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[Redis],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
index_name: Optional[str] = None,
|
||||
content_key: str = "content",
|
||||
metadata_key: str = "metadata",
|
||||
vector_key: str = "content_vector",
|
||||
**kwargs: Any,
|
||||
) -> Redis:
|
||||
"""Create a Redis vectorstore from raw documents.
|
||||
This is a user-friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates a new index for the embeddings in Redis.
|
||||
3. Adds the documents to the newly created Redis index.
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Redis
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
redisearch = RediSearch.from_texts(
|
||||
texts,
|
||||
embeddings,
|
||||
redis_url="redis://username:password@localhost:6379"
|
||||
)
|
||||
"""
|
||||
instance, _ = cls.from_texts_return_keys(
|
||||
texts,
|
||||
embedding,
|
||||
metadatas=metadatas,
|
||||
index_name=index_name,
|
||||
content_key=content_key,
|
||||
metadata_key=metadata_key,
|
||||
vector_key=vector_key,
|
||||
**kwargs,
|
||||
)
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a Redis entry.
|
||||
|
||||
Args:
|
||||
ids: List of ids (keys) to delete.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the deletions were successful.
|
||||
"""
|
||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||
|
||||
if ids is None:
|
||||
raise ValueError("'ids' (keys)() were not provided.")
|
||||
|
||||
try:
|
||||
import redis # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
try:
|
||||
# We need to first remove redis_url from kwargs,
|
||||
# otherwise passing it to Redis will result in an error.
|
||||
if "redis_url" in kwargs:
|
||||
kwargs.pop("redis_url")
|
||||
client = get_client(redis_url=redis_url, **kwargs)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Your redis connected error: {e}")
|
||||
# Check if index exists
|
||||
try:
|
||||
client.delete(*ids)
|
||||
logger.info("Entries deleted")
|
||||
return True
|
||||
except: # noqa: E722
|
||||
# ids does not exist
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def drop_index(
|
||||
index_name: str,
|
||||
delete_documents: bool,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""
|
||||
Drop a Redis search index.
|
||||
|
||||
Args:
|
||||
index_name (str): Name of the index to drop.
|
||||
delete_documents (bool): Whether to drop the associated documents.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the drop was successful.
|
||||
"""
|
||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||
try:
|
||||
import redis # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
try:
|
||||
# We need to first remove redis_url from kwargs,
|
||||
# otherwise passing it to Redis will result in an error.
|
||||
if "redis_url" in kwargs:
|
||||
kwargs.pop("redis_url")
|
||||
client = get_client(redis_url=redis_url, **kwargs)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Your redis connected error: {e}")
|
||||
# Check if index exists
|
||||
try:
|
||||
client.ft(index_name).dropindex(delete_documents)
|
||||
logger.info("Drop index")
|
||||
return True
|
||||
except: # noqa: E722
|
||||
# Index not exist
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_existing_index(
|
||||
cls,
|
||||
embedding: Embeddings,
|
||||
index_name: str,
|
||||
content_key: str = "content",
|
||||
metadata_key: str = "metadata",
|
||||
vector_key: str = "content_vector",
|
||||
**kwargs: Any,
|
||||
) -> Redis:
|
||||
"""Connect to an existing Redis index."""
|
||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||
try:
|
||||
import redis # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
try:
|
||||
# We need to first remove redis_url from kwargs,
|
||||
# otherwise passing it to Redis will result in an error.
|
||||
if "redis_url" in kwargs:
|
||||
kwargs.pop("redis_url")
|
||||
client = get_client(redis_url=redis_url, **kwargs)
|
||||
# check if redis has redisearch module installed
|
||||
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
|
||||
# ensure that the index already exists
|
||||
assert _check_index_exists(
|
||||
client, index_name
|
||||
), f"Index {index_name} does not exist"
|
||||
except Exception as e:
|
||||
raise ValueError(f"Redis failed to connect: {e}")
|
||||
|
||||
return cls(
|
||||
redis_url,
|
||||
index_name,
|
||||
embedding.embed_query,
|
||||
content_key=content_key,
|
||||
metadata_key=metadata_key,
|
||||
vector_key=vector_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
|
||||
tags = kwargs.pop("tags", None) or []
|
||||
tags.extend(self._get_retriever_tags())
|
||||
return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
|
||||
|
||||
|
||||
class RedisVectorStoreRetriever(VectorStoreRetriever):
|
||||
"""Retriever for `Redis` vector store."""
|
||||
|
||||
vectorstore: Redis
|
||||
"""Redis VectorStore."""
|
||||
search_type: str = "similarity"
|
||||
"""Type of search to perform. Can be either 'similarity' or 'similarity_limit'."""
|
||||
k: int = 4
|
||||
"""Number of documents to return."""
|
||||
score_threshold: float = 0.4
|
||||
"""Score threshold for similarity_limit search."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
def validate_search_type(cls, values: Dict) -> Dict:
|
||||
"""Validate search type."""
|
||||
if "search_type" in values:
|
||||
search_type = values["search_type"]
|
||||
if search_type not in ("similarity", "similarity_limit"):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
elif self.search_type == "similarity_limit":
|
||||
docs = self.vectorstore.similarity_search_limit_score(
|
||||
query, k=self.k, score_threshold=self.score_threshold
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Add documents to vectorstore."""
|
||||
return self.vectorstore.add_documents(documents, **kwargs)
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], **kwargs: Any
|
||||
) -> List[str]:
|
||||
"""Add documents to vectorstore."""
|
||||
return await self.vectorstore.aadd_documents(documents, **kwargs)
|
||||
libs/langchain/langchain/vectorstores/redis/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from .base import Redis
from .filters import (
    RedisFilter,
    RedisNum,
    RedisTag,
    RedisText,
)

__all__ = ["Redis", "RedisFilter", "RedisTag", "RedisText", "RedisNum"]
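The package re-exports the vectorstore and the filter helpers, so both can come from a single import. A usage sketch of the new import surface:

```python
# Sketch of the import surface exposed by the new redis package __init__.py.
from langchain.vectorstores.redis import (
    Redis,
    RedisFilter,
    RedisNum,
    RedisTag,
    RedisText,
)
```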
|
||||
libs/langchain/langchain/vectorstores/redis/base.py (1361 lines)

libs/langchain/langchain/vectorstores/redis/constants.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Any, Dict, List

import numpy as np

# required modules
REDIS_REQUIRED_MODULES = [
    {"name": "search", "ver": 20600},
    {"name": "searchlight", "ver": 20600},
]

# distance metrics
REDIS_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"]

# supported vector datatypes
REDIS_VECTOR_DTYPE_MAP: Dict[str, Any] = {
    "FLOAT32": np.float32,
    "FLOAT64": np.float64,
}

REDIS_TAG_SEPARATOR = ","
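These constants centralize the RediSearch version requirements and the supported vector datatypes; the dtype map is what turns embeddings into the byte strings Redis stores. A small sketch of that conversion, mirroring the `_array_to_buffer` helper shown earlier:

```python
# Sketch: serializing an embedding with the dtype map, as the vectorstore does internally.
import numpy as np

REDIS_VECTOR_DTYPE_MAP = {"FLOAT32": np.float32, "FLOAT64": np.float64}

embedding = [0.1, 0.2, 0.3]
buffer = np.array(embedding).astype(REDIS_VECTOR_DTYPE_MAP["FLOAT32"]).tobytes()
print(len(buffer))  # 12 bytes: three float32 values
```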
|
||||
libs/langchain/langchain/vectorstores/redis/filters.py (new file, 420 lines)
@@ -0,0 +1,420 @@
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.utilities.redis import TokenEscaper
|
||||
|
||||
# disable mypy error for dunder method overrides
|
||||
# mypy: disable-error-code="override"
|
||||
|
||||
|
||||
class RedisFilterOperator(Enum):
|
||||
EQ = 1
|
||||
NE = 2
|
||||
LT = 3
|
||||
GT = 4
|
||||
LE = 5
|
||||
GE = 6
|
||||
OR = 7
|
||||
AND = 8
|
||||
LIKE = 9
|
||||
IN = 10
|
||||
|
||||
|
||||
class RedisFilter:
|
||||
@staticmethod
|
||||
def text(field: str) -> "RedisText":
|
||||
return RedisText(field)
|
||||
|
||||
@staticmethod
|
||||
def num(field: str) -> "RedisNum":
|
||||
return RedisNum(field)
|
||||
|
||||
@staticmethod
|
||||
def tag(field: str) -> "RedisTag":
|
||||
return RedisTag(field)
|
||||
|
||||
|
||||
class RedisFilterField:
|
||||
escaper: "TokenEscaper" = TokenEscaper()
|
||||
OPERATORS: Dict[RedisFilterOperator, str] = {}
|
||||
|
||||
def __init__(self, field: str):
|
||||
self._field = field
|
||||
self._value: Any = None
|
||||
self._operator: RedisFilterOperator = RedisFilterOperator.EQ
|
||||
|
||||
def equals(self, other: "RedisFilterField") -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return False
|
||||
return self._field == other._field and self._value == other._value
|
||||
|
||||
def _set_value(
|
||||
self, val: Any, val_type: type, operator: RedisFilterOperator
|
||||
) -> None:
|
||||
# check that the operator is supported by this class
|
||||
if operator not in self.OPERATORS:
|
||||
raise ValueError(
|
||||
f"Operator {operator} not supported by {self.__class__.__name__}. "
|
||||
+ f"Supported operators are {self.OPERATORS.values()}"
|
||||
)
|
||||
|
||||
if not isinstance(val, val_type):
|
||||
raise TypeError(
|
||||
f"Right side argument passed to operator {self.OPERATORS[operator]} "
|
||||
f"with left side "
|
||||
f"argument {self.__class__.__name__} must be of type {val_type}"
|
||||
)
|
||||
self._value = val
|
||||
self._operator = operator
|
||||
|
||||
|
||||
def check_operator_misuse(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
|
||||
# Extracting 'other' from positional arguments or keyword arguments
|
||||
other = kwargs.get("other") if "other" in kwargs else None
|
||||
if not other:
|
||||
for arg in args:
|
||||
if isinstance(arg, type(instance)):
|
||||
other = arg
|
||||
break
|
||||
|
||||
if isinstance(other, type(instance)):
|
||||
raise ValueError(
|
||||
"Equality operators are overridden for FilterExpression creation. Use "
|
||||
".equals() for equality checks"
|
||||
)
|
||||
return func(instance, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RedisTag(RedisFilterField):
|
||||
"""A RedisTag is a RedisFilterField representing a tag in a Redis index."""
|
||||
|
||||
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "==",
|
||||
RedisFilterOperator.NE: "!=",
|
||||
RedisFilterOperator.IN: "==",
|
||||
}
|
||||
|
||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "@%s:{%s}",
|
||||
RedisFilterOperator.NE: "(-@%s:{%s})",
|
||||
RedisFilterOperator.IN: "@%s:{%s}",
|
||||
}
|
||||
|
||||
def __init__(self, field: str):
|
||||
"""Create a RedisTag FilterField
|
||||
|
||||
Args:
|
||||
field (str): The name of the RedisTag field in the index to be queried
|
||||
against.
|
||||
"""
|
||||
super().__init__(field)
|
||||
|
||||
def _set_tag_value(
|
||||
self, other: Union[List[str], str], operator: RedisFilterOperator
|
||||
) -> None:
|
||||
if isinstance(other, list):
|
||||
if not all(isinstance(tag, str) for tag in other):
|
||||
raise ValueError("All tags must be strings")
|
||||
else:
|
||||
other = [other]
|
||||
self._set_value(other, list, operator)
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
||||
"""Create a RedisTag equality filter expression
|
||||
|
||||
Args:
|
||||
other (Union[List[str], str]): The tag(s) to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisTag
|
||||
>>> filter = RedisTag("brand") == "nike"
|
||||
"""
|
||||
self._set_tag_value(other, RedisFilterOperator.EQ)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
||||
"""Create a RedisTag inequality filter expression
|
||||
|
||||
Args:
|
||||
other (Union[List[str], str]): The tag(s) to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisTag
|
||||
>>> filter = RedisTag("brand") != "nike"
|
||||
"""
|
||||
self._set_tag_value(other, RedisFilterOperator.NE)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@property
|
||||
def _formatted_tag_value(self) -> str:
|
||||
return "|".join([self.escaper.escape(tag) for tag in self._value])
|
||||
|
||||
def __str__(self) -> str:
|
||||
if not self._value:
|
||||
raise ValueError(
|
||||
f"Operator must be used before calling __str__. Operators are "
|
||||
f"{self.OPERATORS.values()}"
|
||||
)
|
||||
"""Return the Redis Query syntax for a RedisTag filter expression"""
|
||||
return self.OPERATOR_MAP[self._operator] % (
|
||||
self._field,
|
||||
self._formatted_tag_value,
|
||||
)
|
||||
|
||||
|
||||
class RedisNum(RedisFilterField):
|
||||
"""A RedisFilterField representing a numeric field in a Redis index."""
|
||||
|
||||
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "==",
|
||||
RedisFilterOperator.NE: "!=",
|
||||
RedisFilterOperator.LT: "<",
|
||||
RedisFilterOperator.GT: ">",
|
||||
RedisFilterOperator.LE: "<=",
|
||||
RedisFilterOperator.GE: ">=",
|
||||
}
|
||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "@%s:[%i %i]",
|
||||
RedisFilterOperator.NE: "(-@%s:[%i %i])",
|
||||
RedisFilterOperator.GT: "@%s:[(%i +inf]",
|
||||
RedisFilterOperator.LT: "@%s:[-inf (%i]",
|
||||
RedisFilterOperator.GE: "@%s:[%i +inf]",
|
||||
RedisFilterOperator.LE: "@%s:[-inf %i]",
|
||||
}
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the Redis Query syntax for a Numeric filter expression"""
|
||||
if not self._value:
|
||||
raise ValueError(
|
||||
f"Operator must be used before calling __str__. Operators are "
|
||||
f"{self.OPERATORS.values()}"
|
||||
)
|
||||
|
||||
if (
|
||||
self._operator == RedisFilterOperator.EQ
|
||||
or self._operator == RedisFilterOperator.NE
|
||||
):
|
||||
return self.OPERATOR_MAP[self._operator] % (
|
||||
self._field,
|
||||
self._value,
|
||||
self._value,
|
||||
)
|
||||
else:
|
||||
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(self, other: int) -> "RedisFilterExpression":
|
||||
"""Create a Numeric equality filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") == 90210
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.EQ)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
def __ne__(self, other: int) -> "RedisFilterExpression":
|
||||
"""Create a Numeric inequality filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") != 90210
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.NE)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __gt__(self, other: int) -> "RedisFilterExpression":
|
||||
"""Create a RedisNumeric greater than filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") > 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.GT)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __lt__(self, other: int) -> "RedisFilterExpression":
|
||||
"""Create a Numeric less than filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") < 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.LT)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __ge__(self, other: int) -> "RedisFilterExpression":
|
||||
"""Create a Numeric greater than or equal to filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") >= 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.GE)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __le__(self, other: int) -> "RedisFilterExpression":
|
||||
"""Create a Numeric less than or equal to filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
|
||||
Example:
|
||||
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("age") <= 18
        """
        self._set_value(other, int, RedisFilterOperator.LE)
        return RedisFilterExpression(str(self))


class RedisText(RedisFilterField):
    """A RedisText is a RedisFilterField representing a text field in a Redis index."""

    OPERATORS = {
        RedisFilterOperator.EQ: "==",
        RedisFilterOperator.NE: "!=",
        RedisFilterOperator.LIKE: "%",
    }
    OPERATOR_MAP = {
        RedisFilterOperator.EQ: '@%s:"%s"',
        RedisFilterOperator.NE: '(-@%s:"%s")',
        RedisFilterOperator.LIKE: "@%s:%s",
    }

    @check_operator_misuse
    def __eq__(self, other: str) -> "RedisFilterExpression":
        """Create a RedisText equality filter expression

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisText
            >>> filter = RedisText("job") == "engineer"
        """
        self._set_value(other, str, RedisFilterOperator.EQ)
        return RedisFilterExpression(str(self))

    @check_operator_misuse
    def __ne__(self, other: str) -> "RedisFilterExpression":
        """Create a RedisText inequality filter expression

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisText
            >>> filter = RedisText("job") != "engineer"
        """
        self._set_value(other, str, RedisFilterOperator.NE)
        return RedisFilterExpression(str(self))

    def __mod__(self, other: str) -> "RedisFilterExpression":
        """Create a RedisText like filter expression

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisText
            >>> filter = RedisText("job") % "engineer"
        """
        self._set_value(other, str, RedisFilterOperator.LIKE)
        return RedisFilterExpression(str(self))

    def __str__(self) -> str:
        if not self._value:
            raise ValueError(
                f"Operator must be used before calling __str__. Operators are "
                f"{self.OPERATORS.values()}"
            )

        try:
            return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
        except KeyError:
            raise Exception("Invalid operator")


class RedisFilterExpression:
    """A RedisFilterExpression is a logical expression of RedisFilterFields.

    RedisFilterExpressions can be combined using the & and | operators to create
    complex logical expressions that evaluate to the Redis Query language.

    This presents an interface by which users can create complex queries
    without having to know the Redis Query language.

    Filter expressions are not initialized directly. Instead they are built
    by combining RedisFilterFields using the & and | operators.

    Examples:

        >>> from langchain.vectorstores.redis import RedisTag, RedisNum
        >>> brand_is_nike = RedisTag("brand") == "nike"
        >>> price_is_under_100 = RedisNum("price") < 100
        >>> filter = brand_is_nike & price_is_under_100
        >>> print(str(filter))
        (@brand:{nike} @price:[-inf (100)])

    """

    def __init__(
        self,
        _filter: Optional[str] = None,
        operator: Optional[RedisFilterOperator] = None,
        left: Optional["RedisFilterExpression"] = None,
        right: Optional["RedisFilterExpression"] = None,
    ):
        self._filter = _filter
        self._operator = operator
        self._left = left
        self._right = right

    def __and__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
        return RedisFilterExpression(
            operator=RedisFilterOperator.AND, left=self, right=other
        )

    def __or__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
        return RedisFilterExpression(
            operator=RedisFilterOperator.OR, left=self, right=other
        )

    def __str__(self) -> str:
        # top level check that allows recursive calls to __str__
        if not self._filter and not self._operator:
            raise ValueError("Improperly initialized RedisFilterExpression")

        # allow for single filter expression without operators as last
        # expression in the chain might not have an operator
        if self._operator:
            operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
            return f"({str(self._left)}{operator_str}{str(self._right)})"

        # check that base case, the filter is set
        if not self._filter:
            raise ValueError("Improperly initialized RedisFilterExpression")
        return self._filter
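For orientation, a small self-contained sketch (not part of the diff) of how these filter fields compose; the field names and values below are made up, and no Redis connection is needed just to build the expression:

    from langchain.vectorstores.redis import RedisNum, RedisTag, RedisText

    # Each comparison on a field object returns a RedisFilterExpression.
    is_engineer = RedisText("job") % "engineer*"               # wildcard / LIKE match
    is_adult = RedisNum("age") >= 18                           # numeric range
    is_known_brand = RedisTag("brand") == ["nike", "adidas"]   # tag membership

    # & intersects and | unions; str() renders the Redis Query syntax
    # that later gets attached to the vector search.
    combined = (is_engineer & is_adult) | is_known_brand
    print(str(combined))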
276
libs/langchain/langchain/vectorstores/redis/schema.py
Normal file
@@ -0,0 +1,276 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import yaml

# ignore type error here as it's a redis-py type problem
from redis.commands.search.field import (  # type: ignore
    NumericField,
    TagField,
    TextField,
    VectorField,
)
from typing_extensions import Literal

from langchain.pydantic_v1 import BaseModel, Field, validator
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP


class RedisDistanceMetric(str, Enum):
    l2 = "L2"
    cosine = "COSINE"
    ip = "IP"


class RedisField(BaseModel):
    name: str = Field(...)


class TextFieldSchema(RedisField):
    weight: float = 1
    no_stem: bool = False
    phonetic_matcher: Optional[str] = None
    withsuffixtrie: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TextField:
        return TextField(
            self.name,
            weight=self.weight,
            no_stem=self.no_stem,
            phonetic_matcher=self.phonetic_matcher,
            sortable=self.sortable,
            no_index=self.no_index,
        )


class TagFieldSchema(RedisField):
    separator: str = ","
    case_sensitive: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TagField:
        return TagField(
            self.name,
            separator=self.separator,
            case_sensitive=self.case_sensitive,
            sortable=self.sortable,
            no_index=self.no_index,
        )


class NumericFieldSchema(RedisField):
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> NumericField:
        return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)


class RedisVectorField(RedisField):
    dims: int = Field(...)
    algorithm: object = Field(...)
    datatype: str = Field(default="FLOAT32")
    distance_metric: RedisDistanceMetric = Field(default="COSINE")
    initial_cap: int = Field(default=20000)

    @validator("distance_metric", pre=True)
    def uppercase_strings(cls, v: str) -> str:
        return v.upper()

    @validator("datatype", pre=True)
    def uppercase_and_check_dtype(cls, v: str) -> str:
        if v.upper() not in REDIS_VECTOR_DTYPE_MAP:
            raise ValueError(
                f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
            )
        return v.upper()


class FlatVectorField(RedisVectorField):
    algorithm: Literal["FLAT"] = "FLAT"
    block_size: int = Field(default=1000)

    def as_field(self) -> VectorField:
        return VectorField(
            self.name,
            self.algorithm,
            {
                "TYPE": self.datatype,
                "DIM": self.dims,
                "DISTANCE_METRIC": self.distance_metric,
                "INITIAL_CAP": self.initial_cap,
                "BLOCK_SIZE": self.block_size,
            },
        )


class HNSWVectorField(RedisVectorField):
    algorithm: Literal["HNSW"] = "HNSW"
    m: int = Field(default=16)
    ef_construction: int = Field(default=200)
    ef_runtime: int = Field(default=10)
    epsilon: float = Field(default=0.8)

    def as_field(self) -> VectorField:
        return VectorField(
            self.name,
            self.algorithm,
            {
                "TYPE": self.datatype,
                "DIM": self.dims,
                "DISTANCE_METRIC": self.distance_metric,
                "INITIAL_CAP": self.initial_cap,
                "M": self.m,
                "EF_CONSTRUCTION": self.ef_construction,
                "EF_RUNTIME": self.ef_runtime,
                "EPSILON": self.epsilon,
            },
        )


class RedisModel(BaseModel):
    # always have a content field for text
    text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
    tag: Optional[List[TagFieldSchema]] = None
    numeric: Optional[List[NumericFieldSchema]] = None
    extra: Optional[List[RedisField]] = None

    # filled by default_vector_schema
    vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
    content_key: str = "content"
    content_vector_key: str = "content_vector"

    def add_content_field(self) -> None:
        if self.text is None:
            self.text = []
        for field in self.text:
            if field.name == self.content_key:
                return
        self.text.append(TextFieldSchema(name=self.content_key))

    def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
        # catch case where user inputted no vector field spec
        # in the index schema
        if self.vector is None:
            self.vector = []

        # ignore types as pydantic is handling type validation and conversion
        if vector_field["algorithm"] == "FLAT":
            self.vector.append(FlatVectorField(**vector_field))  # type: ignore
        elif vector_field["algorithm"] == "HNSW":
            self.vector.append(HNSWVectorField(**vector_field))  # type: ignore
        else:
            raise ValueError(
                f"algorithm must be either FLAT or HNSW. Got "
                f"{vector_field['algorithm']}"
            )

    def as_dict(self) -> Dict[str, List[Any]]:
        schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
        # iter over all class attributes
        for attr, attr_value in self.__dict__.items():
            # only non-empty lists
            if isinstance(attr_value, list) and len(attr_value) > 0:
                field_values: List[Dict[str, Any]] = []
                # iterate over all fields in each category (tag, text, etc)
                for val in attr_value:
                    value: Dict[str, Any] = {}
                    # iterate over values within each field to extract
                    # settings for that field (i.e. name, weight, etc)
                    for field, field_value in val.__dict__.items():
                        # make enums into strings
                        if isinstance(field_value, Enum):
                            value[field] = field_value.value
                        # don't write null values
                        elif field_value is not None:
                            value[field] = field_value
                    field_values.append(value)

                schemas[attr] = field_values

        schema: Dict[str, List[Any]] = {}
        # only write non-empty lists from defaults
        for k, v in schemas.items():
            if len(v) > 0:
                schema[k] = v
        return schema

    @property
    def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
        if not self.vector:
            raise ValueError("No vector fields found")
        for field in self.vector:
            if field.name == self.content_vector_key:
                return field
        raise ValueError("No content_vector field found")

    @property
    def vector_dtype(self) -> np.dtype:
        # should only ever be called after pydantic has validated the schema
        return REDIS_VECTOR_DTYPE_MAP[self.content_vector.datatype]

    @property
    def is_empty(self) -> bool:
        return all(
            field is None for field in [self.tag, self.text, self.numeric, self.vector]
        )

    def get_fields(self) -> List["RedisField"]:
        redis_fields: List["RedisField"] = []
        if self.is_empty:
            return redis_fields

        for field_name in self.__fields__.keys():
            if field_name not in ["content_key", "content_vector_key", "extra"]:
                field_group = getattr(self, field_name)
                if field_group is not None:
                    for field in field_group:
                        redis_fields.append(field.as_field())
        return redis_fields

    @property
    def metadata_keys(self) -> List[str]:
        keys: List[str] = []
        if self.is_empty:
            return keys

        for field_name in self.__fields__.keys():
            field_group = getattr(self, field_name)
            if field_group is not None:
                for field in field_group:
                    # check if it's a metadata field. exclude vector and content key
                    if not isinstance(field, str) and field.name not in [
                        self.content_key,
                        self.content_vector_key,
                    ]:
                        keys.append(field.name)
        return keys


def read_schema(
    index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
) -> Dict[str, Any]:
    # check if its a dict and return RedisModel otherwise, check if it's a path and
    # read in the file assuming it's a yaml file and return a RedisModel
    if isinstance(index_schema, dict):
        return index_schema
    elif isinstance(index_schema, Path):
        with open(index_schema, "rb") as f:
            return yaml.safe_load(f)
    elif isinstance(index_schema, str):
        if Path(index_schema).resolve().is_file():
            with open(index_schema, "rb") as f:
                return yaml.safe_load(f)
        else:
            raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
    else:
        raise TypeError(
            f"index_schema must be a dict, or path to a yaml file "
            f"Got {type(index_schema)}"
        )
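As a quick aside (not part of the diff): `read_schema` accepts either a plain dict or a path to a YAML file with the same shape, so a custom metadata schema can be defined once and reused. A minimal sketch with a made-up schema and file name:

    import yaml

    from langchain.vectorstores.redis.schema import read_schema

    # Hypothetical metadata schema: two text fields, a numeric field and a tag field.
    index_schema = {
        "text": [{"name": "job"}, {"name": "title"}],
        "numeric": [{"name": "salary"}],
        "tag": [{"name": "category"}],
    }

    assert read_schema(index_schema) == index_schema  # dicts pass straight through

    with open("redis_schema.yml", "w") as f:          # the same structure as YAML
        yaml.safe_dump(index_schema, f)
    assert read_schema("redis_schema.yml") == index_schema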
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.273"
version = "0.0.274"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -1,16 +1,27 @@
"""Test Redis cache functionality."""
import uuid
from typing import List

import pytest

import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.embeddings.base import Embeddings
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
    FakeEmbeddings,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM

REDIS_TEST_URL = "redis://localhost:6379"


def random_string() -> str:
    return str(uuid.uuid4())


def test_redis_cache_ttl() -> None:
    import redis

@@ -30,12 +41,10 @@ def test_redis_cache() -> None:
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    print(output)
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    print(expected_output)
    assert output == expected_output
    langchain.llm_cache.redis.flushall()

@@ -80,14 +89,90 @@ def test_redis_semantic_cache() -> None:
    langchain.llm_cache.clear(llm_string=llm_string)


def test_redis_semantic_cache_chat() -> None:
    import redis
def test_redis_semantic_cache_multi() -> None:
    langchain.llm_cache = RedisSemanticCache(
        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    expected_output = LLMResult(
        generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    langchain.llm_cache.clear(llm_string=llm_string)

    langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))

def test_redis_semantic_cache_chat() -> None:
    langchain.llm_cache = RedisSemanticCache(
        embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
    )
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    with pytest.warns():
        llm.predict("foo")
        llm.predict("foo")
    langchain.llm_cache.redis.flushall()
    langchain.llm_cache.clear(llm_string=llm_string)


@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@pytest.mark.parametrize(
    "prompts, generations",
    [
        # Single prompt, single generation
        ([random_string()], [[random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Multiple prompts, multiple generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
    ids=[
        "single_prompt_single_generation",
        "single_prompt_multiple_generations",
        "single_prompt_multiple_generations",
        "multiple_prompts_multiple_generations",
    ],
)
def test_redis_semantic_cache_hit(
    embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
    langchain.llm_cache = RedisSemanticCache(
        embedding=embedding, redis_url=REDIS_TEST_URL
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))

    llm_generations = [
        [
            Generation(text=generation, generation_info=params)
            for generation in prompt_i_generations
        ]
        for prompt_i_generations in generations
    ]
    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
        print(prompt_i)
        print(llm_generations_i)
        langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
    llm.generate(prompts)
    assert llm.generate(prompts) == LLMResult(
        generations=llm_generations, llm_output={}
    )
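A compact sketch (not part of the diff) of the same wiring the semantic-cache tests use, for running outside the test suite; `FakeEmbeddings` only stands in for a real embedding model so nothing here needs an API key, and the Redis URL is assumed to point at a local instance:

    import langchain
    from langchain.cache import RedisSemanticCache
    from langchain.embeddings import FakeEmbeddings

    # LLM calls made after this line consult the semantic cache first
    # (unless caching is disabled on the model).
    langchain.llm_cache = RedisSemanticCache(
        embedding=FakeEmbeddings(size=1352),
        redis_url="redis://localhost:6379",
        score_threshold=0.1,  # similarity cut-off for treating a new prompt as a hit
    )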
@@ -1,18 +1,22 @@
"""Test ChatOpenAI wrapper."""


from typing import Any
from typing import Any, List, Optional, Union

import pytest

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.chains.openai_functions import (
    create_openai_fn_chain,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
            assert generation.text == generation.message.content


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_async_chat_openai_streaming_with_function() -> None:
    """Test ChatOpenAI wrapper with multiple completions."""

    class MyCustomAsyncHandler(AsyncCallbackHandler):
        def __init__(self) -> None:
            super().__init__()
            self._captured_tokens: List[str] = []
            self._captured_chunks: List[
                Optional[Union[ChatGenerationChunk, GenerationChunk]]
            ] = []

        def on_llm_new_token(
            self,
            token: str,
            *,
            chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
            **kwargs: Any,
        ) -> Any:
            self._captured_tokens.append(token)
            self._captured_chunks.append(chunk)

    json_schema = {
        "title": "Person",
        "description": "Identifying information about a person.",
        "type": "object",
        "properties": {
            "name": {
                "title": "Name",
                "description": "The person's name",
                "type": "string",
            },
            "age": {
                "title": "Age",
                "description": "The person's age",
                "type": "integer",
            },
            "fav_food": {
                "title": "Fav Food",
                "description": "The person's favorite food",
                "type": "string",
            },
        },
        "required": ["name", "age"],
    }

    callback_handler = MyCustomAsyncHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatOpenAI(
        max_tokens=10,
        n=1,
        callback_manager=callback_manager,
        streaming=True,
    )

    prompt_msgs = [
        SystemMessage(
            content="You are a world class algorithm for "
            "extracting information in structured formats."
        ),
        HumanMessage(
            content="Use the given format to extract "
            "information from the following input:"
        ),
        HumanMessagePromptTemplate.from_template("{input}"),
        HumanMessage(content="Tips: Make sure to answer in the correct format"),
    ]
    prompt = ChatPromptTemplate(messages=prompt_msgs)

    function: Any = {
        "name": "output_formatter",
        "description": (
            "Output formatter. Should always be used to format your response to the"
            " user."
        ),
        "parameters": json_schema,
    }
    chain = create_openai_fn_chain(
        [function],
        chat,
        prompt,
        output_parser=None,
    )

    message = HumanMessage(content="Sally is 13 years old")
    response = await chain.agenerate([{"input": message}])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
    assert len(callback_handler._captured_tokens) > 0
    assert len(callback_handler._captured_chunks) > 0
    assert all([chunk is not None for chunk in callback_handler._captured_chunks])


def test_chat_openai_extra_kwargs() -> None:
    """Test extra kwargs to chat openai."""
    # Check that foo is saved in extra_kwargs.
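For reference, a stripped-down sketch (not part of the diff) of the same `create_openai_fn_chain` call pattern the new test exercises, without the streaming callbacks; it assumes an OpenAI API key is configured, and the schema is a made-up two-field variant of the test's "Person" schema:

    from langchain.chains.openai_functions import create_openai_fn_chain
    from langchain.chat_models.openai import ChatOpenAI
    from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

    # Hypothetical function schema the model is asked to call.
    output_formatter = {
        "name": "output_formatter",
        "description": "Output formatter. Should always be used to format your response.",
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        },
    }

    prompt = ChatPromptTemplate(
        messages=[HumanMessagePromptTemplate.from_template("Extract the person from: {input}")]
    )
    chain = create_openai_fn_chain([output_formatter], ChatOpenAI(temperature=0), prompt)
    print(chain.run(input="Sally is 13 years old"))  # expected, roughly: {'name': 'Sally', 'age': 13}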
@@ -52,6 +52,7 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
    def embed_query(self, text: str) -> List[float]:
        """Return consistent embeddings for the text, if seen before, or a constant
        one if the text is unknown."""
        return self.embed_documents([text])[0]
        if text not in self.known_texts:
            return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
        return [float(1.0)] * (self.dimensionality - 1) + [
@@ -1,17 +1,28 @@
"""Test Redis functionality."""
from typing import List
import os
from typing import Any, Dict, List, Optional

import pytest

from langchain.docstore.document import Document
from langchain.vectorstores.redis import Redis
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from langchain.vectorstores.redis import (
    Redis,
    RedisFilter,
    RedisNum,
    RedisText,
)
from langchain.vectorstores.redis.filters import RedisFilterExpression
from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
    FakeEmbeddings,
)

TEST_INDEX_NAME = "test"
TEST_REDIS_URL = "redis://localhost:6379"
TEST_SINGLE_RESULT = [Document(page_content="foo")]
TEST_SINGLE_WITH_METADATA_RESULT = [Document(page_content="foo", metadata={"a": "b"})]
TEST_SINGLE_WITH_METADATA = {"a": "b"}
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
RANGE_SCORE = pytest.approx(0.0513, abs=0.002)
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
IP_SCORE = -8.0
EUCLIDEAN_SCORE = 1.0
@@ -23,6 +34,27 @@ def drop(index_name: str) -> bool:
    )


def convert_bytes(data: Any) -> Any:
    if isinstance(data, bytes):
        return data.decode("ascii")
    if isinstance(data, dict):
        return dict(map(convert_bytes, data.items()))
    if isinstance(data, list):
        return list(map(convert_bytes, data))
    if isinstance(data, tuple):
        return map(convert_bytes, data)
    return data


def make_dict(values: List[Any]) -> dict:
    i = 0
    di = {}
    while i < len(values) - 1:
        di[values[i]] = values[i + 1]
        i += 2
    return di


@pytest.fixture
def texts() -> List[str]:
    return ["foo", "bar", "baz"]
@@ -31,7 +63,7 @@ def texts() -> List[str]:
def test_redis(texts: List[str]) -> None:
    """Test end to end construction and search."""
    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
    output = docsearch.similarity_search("foo", k=1)
    output = docsearch.similarity_search("foo", k=1, return_metadata=False)
    assert output == TEST_SINGLE_RESULT
    assert drop(docsearch.index_name)

@@ -40,30 +72,55 @@ def test_redis_new_vector(texts: List[str]) -> None:
    """Test adding a new document"""
    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
    docsearch.add_texts(["foo"])
    output = docsearch.similarity_search("foo", k=2)
    output = docsearch.similarity_search("foo", k=2, return_metadata=False)
    assert output == TEST_RESULT
    assert drop(docsearch.index_name)


def test_redis_from_existing(texts: List[str]) -> None:
    """Test adding a new document"""
    Redis.from_texts(
    docsearch = Redis.from_texts(
        texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
    )
    schema: Dict = docsearch.schema

    # write schema for the next test
    docsearch.write_schema("test_schema.yml")

    # Test creating from an existing
    docsearch2 = Redis.from_existing_index(
        FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
        FakeEmbeddings(),
        index_name=TEST_INDEX_NAME,
        redis_url=TEST_REDIS_URL,
        schema=schema,
    )
    output = docsearch2.similarity_search("foo", k=1)
    output = docsearch2.similarity_search("foo", k=1, return_metadata=False)
    assert output == TEST_SINGLE_RESULT


def test_redis_add_texts_to_existing() -> None:
    """Test adding a new document"""
    # Test creating from an existing with yaml from file
    docsearch = Redis.from_existing_index(
        FakeEmbeddings(),
        index_name=TEST_INDEX_NAME,
        redis_url=TEST_REDIS_URL,
        schema="test_schema.yml",
    )
    docsearch.add_texts(["foo"])
    output = docsearch.similarity_search("foo", k=2, return_metadata=False)
    assert output == TEST_RESULT
    assert drop(TEST_INDEX_NAME)
    # remove the test_schema.yml file
    os.remove("test_schema.yml")


def test_redis_from_texts_return_keys(texts: List[str]) -> None:
    """Test from_texts_return_keys constructor."""
    docsearch, keys = Redis.from_texts_return_keys(
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL
    )
    output = docsearch.similarity_search("foo", k=1)
    output = docsearch.similarity_search("foo", k=1, return_metadata=False)
    assert output == TEST_SINGLE_RESULT
    assert len(keys) == len(texts)
    assert drop(docsearch.index_name)
@@ -73,21 +130,124 @@ def test_redis_from_documents(texts: List[str]) -> None:
    """Test from_documents constructor."""
    docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
    docsearch = Redis.from_documents(docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
    output = docsearch.similarity_search("foo", k=1)
    assert output == TEST_SINGLE_WITH_METADATA_RESULT
    output = docsearch.similarity_search("foo", k=1, return_metadata=True)
    assert "a" in output[0].metadata.keys()
    assert "b" in output[0].metadata.values()
    assert drop(docsearch.index_name)


def test_redis_add_texts_to_existing() -> None:
    """Test adding a new document"""
    # Test creating from an existing
    docsearch = Redis.from_existing_index(
        FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
# -- test filters -- #


@pytest.mark.parametrize(
    "filter_expr, expected_length, expected_nums",
    [
        (RedisText("text") == "foo", 1, None),
        (RedisFilter.text("text") == "foo", 1, None),
        (RedisText("text") % "ba*", 2, ["bar", "baz"]),
        (RedisNum("num") > 2, 1, [3]),
        (RedisNum("num") < 2, 1, [1]),
        (RedisNum("num") >= 2, 2, [2, 3]),
        (RedisNum("num") <= 2, 2, [1, 2]),
        (RedisNum("num") != 2, 2, [1, 3]),
        (RedisFilter.num("num") != 2, 2, [1, 3]),
        (RedisFilter.tag("category") == "a", 3, None),
        (RedisFilter.tag("category") == "b", 2, None),
        (RedisFilter.tag("category") == "c", 2, None),
        (RedisFilter.tag("category") == ["b", "c"], 3, None),
    ],
    ids=[
        "text-filter-equals-foo",
        "alternative-text-equals-foo",
        "text-filter-fuzzy-match-ba",
        "number-filter-greater-than-2",
        "number-filter-less-than-2",
        "number-filter-greater-equals-2",
        "number-filter-less-equals-2",
        "number-filter-not-equals-2",
        "alternative-number-not-equals-2",
        "tag-filter-equals-a",
        "tag-filter-equals-b",
        "tag-filter-equals-c",
        "tag-filter-equals-b-or-c",
    ],
)
def test_redis_filters_1(
    filter_expr: RedisFilterExpression,
    expected_length: int,
    expected_nums: Optional[list],
) -> None:
    metadata = [
        {"name": "joe", "num": 1, "text": "foo", "category": ["a", "b"]},
        {"name": "john", "num": 2, "text": "bar", "category": ["a", "c"]},
        {"name": "jane", "num": 3, "text": "baz", "category": ["b", "c", "a"]},
    ]
    documents = [Document(page_content="foo", metadata=m) for m in metadata]
    docsearch = Redis.from_documents(
        documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL
    )
    docsearch.add_texts(["foo"])
    output = docsearch.similarity_search("foo", k=2)
    assert output == TEST_RESULT
    assert drop(TEST_INDEX_NAME)

    output = docsearch.similarity_search("foo", k=3, filter=filter_expr)

    assert len(output) == expected_length

    if expected_nums is not None:
        for out in output:
            assert (
                out.metadata["text"] in expected_nums
                or int(out.metadata["num"]) in expected_nums
            )

    assert drop(docsearch.index_name)


# -- test index specification -- #


def test_index_specification_generation() -> None:
    index_schema = {
        "text": [{"name": "job"}, {"name": "title"}],
        "numeric": [{"name": "salary"}],
    }

    text = ["foo"]
    meta = {"job": "engineer", "title": "principal engineer", "salary": 100000}
    docs = [Document(page_content=t, metadata=meta) for t in text]
    r = Redis.from_documents(
        docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, index_schema=index_schema
    )

    output = r.similarity_search("foo", k=1, return_metadata=True)
    assert output[0].metadata["job"] == "engineer"
    assert output[0].metadata["title"] == "principal engineer"
    assert int(output[0].metadata["salary"]) == 100000

    info = convert_bytes(r.client.ft(r.index_name).info())
    attributes = info["attributes"]
    assert len(attributes) == 5
    for attr in attributes:
        d = make_dict(attr)
        if d["identifier"] == "job":
            assert d["type"] == "TEXT"
        elif d["identifier"] == "title":
            assert d["type"] == "TEXT"
        elif d["identifier"] == "salary":
            assert d["type"] == "NUMERIC"
        elif d["identifier"] == "content":
            assert d["type"] == "TEXT"
        elif d["identifier"] == "content_vector":
            assert d["type"] == "VECTOR"
        else:
            raise ValueError("Unexpected attribute in index schema")

    assert drop(r.index_name)


# -- test distance metrics -- #

cosine_schema: Dict = {"distance_metric": "cosine"}
ip_schema: Dict = {"distance_metric": "IP"}
l2_schema: Dict = {"distance_metric": "L2"}


def test_cosine(texts: List[str]) -> None:
@@ -96,7 +256,7 @@ def test_cosine(texts: List[str]) -> None:
        texts,
        FakeEmbeddings(),
        redis_url=TEST_REDIS_URL,
        distance_metric="COSINE",
        vector_schema=cosine_schema,
    )
    output = docsearch.similarity_search_with_score("far", k=2)
    _, score = output[1]
@@ -107,7 +267,7 @@ def test_cosine(texts: List[str]) -> None:
def test_l2(texts: List[str]) -> None:
    """Test Flat L2 distance."""
    docsearch = Redis.from_texts(
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=l2_schema
    )
    output = docsearch.similarity_search_with_score("far", k=2)
    _, score = output[1]
@@ -118,7 +278,7 @@ def test_l2(texts: List[str]) -> None:
def test_ip(texts: List[str]) -> None:
    """Test inner product distance."""
    docsearch = Redis.from_texts(
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=ip_schema
    )
    output = docsearch.similarity_search_with_score("far", k=2)
    _, score = output[1]
@@ -126,29 +286,34 @@ def test_ip(texts: List[str]) -> None:
    assert drop(docsearch.index_name)


def test_similarity_search_limit_score(texts: List[str]) -> None:
def test_similarity_search_limit_distance(texts: List[str]) -> None:
    """Test similarity search limit score."""
    docsearch = Redis.from_texts(
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
        texts,
        FakeEmbeddings(),
        redis_url=TEST_REDIS_URL,
    )
    output = docsearch.similarity_search_limit_score("far", k=2, score_threshold=0.1)
    assert len(output) == 1
    _, score = output[0]
    assert score == COSINE_SCORE
    output = docsearch.similarity_search(texts[0], k=3, distance_threshold=0.1)

    # can't check score but length of output should be 2
    assert len(output) == 2
    assert drop(docsearch.index_name)


def test_similarity_search_with_score_with_limit_score(texts: List[str]) -> None:
def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None:
    """Test similarity search with score with limit score."""

    docsearch = Redis.from_texts(
        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
        texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
    )
    output = docsearch.similarity_search_with_relevance_scores(
        "far", k=2, score_threshold=0.1
    output = docsearch.similarity_search_with_score(
        texts[0], k=3, distance_threshold=0.1, return_metadata=True
    )
    assert len(output) == 1
    _, score = output[0]
    assert score == COSINE_SCORE

    assert len(output) == 2
    for out, score in output:
        if out.page_content == texts[1]:
            score == COSINE_SCORE
    assert drop(docsearch.index_name)


@@ -156,6 +321,48 @@ def test_delete(texts: List[str]) -> None:
    """Test deleting a new document"""
    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
    ids = docsearch.add_texts(["foo"])
    got = docsearch.delete(ids=ids)
    got = docsearch.delete(ids=ids, redis_url=TEST_REDIS_URL)
    assert got
    assert drop(docsearch.index_name)


def test_redis_as_retriever() -> None:
    texts = ["foo", "foo", "foo", "foo", "bar"]
    docsearch = Redis.from_texts(
        texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
    )

    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    results = retriever.get_relevant_documents("foo")
    assert len(results) == 3
    assert all([d.page_content == "foo" for d in results])

    assert drop(docsearch.index_name)


def test_redis_retriever_distance_threshold() -> None:
    texts = ["foo", "bar", "baz"]
    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)

    retriever = docsearch.as_retriever(
        search_type="similarity_distance_threshold",
        search_kwargs={"k": 3, "distance_threshold": 0.1},
    )
    results = retriever.get_relevant_documents("foo")
    assert len(results) == 2

    assert drop(docsearch.index_name)


def test_redis_retriever_score_threshold() -> None:
    texts = ["foo", "bar", "baz"]
    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)

    retriever = docsearch.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 3, "score_threshold": 0.91},
    )
    results = retriever.get_relevant_documents("foo")
    assert len(results) == 2

    assert drop(docsearch.index_name)
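A short usage sketch (not part of the diff) mirroring the retriever tests above, runnable outside pytest; the texts are throwaway, `FakeEmbeddings` from langchain.embeddings stands in for a real embedding model, and a local Redis is assumed at localhost:6379:

    from langchain.embeddings import FakeEmbeddings
    from langchain.vectorstores.redis import Redis

    docsearch = Redis.from_texts(
        ["foo", "bar", "baz"],
        FakeEmbeddings(size=1352),
        redis_url="redis://localhost:6379",
    )

    # Only documents within the vector distance threshold are returned.
    retriever = docsearch.as_retriever(
        search_type="similarity_distance_threshold",
        search_kwargs={"k": 3, "distance_threshold": 0.1},
    )
    print(retriever.get_relevant_documents("foo"))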
@@ -12,6 +12,21 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
            "This Agreement is governed by English law.\n",
            "28-pl",
        ),
        (
            "This Agreement is governed by English law.\nSources: 28-pl",
            "This Agreement is governed by English law.\n",
            "28-pl",
        ),
        (
            "This Agreement is governed by English law.\nsource: 28-pl",
            "This Agreement is governed by English law.\n",
            "28-pl",
        ),
        (
            "This Agreement is governed by English law.\nSource: 28-pl",
            "This Agreement is governed by English law.\n",
            "28-pl",
        ),
        (
            "This Agreement is governed by English law.\n"
            "SOURCES: 28-pl\n\n"