Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-07 17:50:35 +00:00)
Compare commits: harrison/a...harrison/x (1 commit)
Commit: 34d3d5d807
@@ -10,7 +10,7 @@ import argparse
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
# Base URL for all class documentation
|
||||
_BASE_URL = "https://api.python.langchain.com/en/latest/"
|
||||
_BASE_URL = "https://api.python.langchain.com/en/latest"
|
||||
|
||||
# Regular expression to match Python code blocks
|
||||
code_block_re = re.compile(r"^(```python\n)(.*?)(```\n)", re.DOTALL | re.MULTILINE)
|
||||
|
||||
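Purely as an illustration (the script's surrounding logic is elided from this diff), a helper built on `_BASE_URL` and `code_block_re` might rewrite docs like the sketch below; the function name and linking scheme are hypothetical. Note how the trailing slash removed from `_BASE_URL` above pairs with an explicit `/` at each join site:

```python
# Hypothetical sketch, not the actual script from this diff.
import re

_BASE_URL = "https://api.python.langchain.com/en/latest"  # no trailing slash
code_block_re = re.compile(r"^(```python\n)(.*?)(```\n)", re.DOTALL | re.MULTILINE)


def link_imports(markdown: str) -> str:
    """Append API-reference links after each python code block that imports from langchain."""

    def _annotate(match: re.Match) -> str:
        fence_open, body, fence_close = match.groups()
        modules = re.findall(r"from (langchain\.\S+) import", body)
        links = "".join(
            f"[{m} API reference]({_BASE_URL}/{m.replace('.', '/')})\n" for m in modules
        )
        return fence_open + body + fence_close + links

    return code_block_re.sub(_annotate, markdown)
```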
@@ -3476,10 +3476,6 @@
|
||||
"source": "/en/latest/modules/prompts/output_parsers/examples/retry.html",
|
||||
"destination": "/docs/modules/model_io/output_parsers/retry"
|
||||
},
|
||||
{
|
||||
"source": "/en/latest/modules/prompts/example_selectors.html",
|
||||
"destination": "/docs/modules/model_io/example_selectors"
|
||||
},
|
||||
{
|
||||
"source": "/en/latest/modules/prompts/example_selectors/examples/custom_example_selector.html",
|
||||
"destination": "/docs/modules/model_io/prompts/example_selectors/custom_example_selector"
|
||||
@@ -3492,10 +3488,6 @@
|
||||
"source": "/en/latest/modules/prompts/example_selectors/examples/ngram_overlap.html",
|
||||
"destination": "/docs/modules/model_io/prompts/example_selectors/ngram_overlap"
|
||||
},
|
||||
{
|
||||
"source": "/en/latest/modules/prompts/prompt_templates.html",
|
||||
"destination": "/docs/modules/model_io/prompt_templates"
|
||||
},
|
||||
{
|
||||
"source": "/en/latest/modules/prompts/prompt_templates/examples/connecting_to_a_feature_store.html",
|
||||
"destination": "/docs/modules/model_io/prompts/prompt_templates/connecting_to_a_feature_store"
|
||||
@@ -3748,10 +3740,6 @@
|
||||
"source": "/docs/modules/evaluation/:path*(/?)",
|
||||
"destination": "/docs/guides/evaluation/:path*"
|
||||
},
|
||||
{
|
||||
"source": "/en/latest/modules/indexes.html",
|
||||
"destination": "/docs/modules/data_connection"
|
||||
},
|
||||
{
|
||||
"source": "/en/latest/modules/indexes/:path*",
|
||||
"destination": "/docs/modules/data_connection/:path*"
|
||||
|
||||
@@ -22,19 +22,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 4,
|
||||
"id": "466b65b3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
@@ -42,7 +33,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 5,
|
||||
"id": "3c634ef0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -592,98 +583,6 @@
|
||||
"chain2.invoke({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d094d637",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Router\n",
|
||||
"\n",
|
||||
"You can also use the router runnable to conditionally route inputs to different runnables."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "252625fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import create_tagging_chain_pydantic\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"\n",
|
||||
"class PromptToUse(BaseModel):\n",
|
||||
" \"\"\"Used to determine which prompt to use to answer the user's input.\"\"\"\n",
|
||||
" \n",
|
||||
" name: str = Field(description=\"Should be one of `math` or `english`\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "57886e84",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tagger = create_tagging_chain_pydantic(PromptToUse, ChatOpenAI(temperature=0))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a303b089",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain1 = ChatPromptTemplate.from_template(\"You are a math genius. Answer the question: {question}\") | ChatOpenAI()\n",
|
||||
"chain2 = ChatPromptTemplate.from_template(\"You are an english major. Answer the question: {question}\") | ChatOpenAI()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7aa9ea06",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema.runnable import RouterRunnable\n",
|
||||
"router = RouterRunnable({\"math\": chain1, \"english\": chain2})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "6a3d3f5d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = {\n",
|
||||
" \"key\": {\"input\": lambda x: x[\"question\"]} | tagger | (lambda x: x['text'].name),\n",
|
||||
" \"input\": {\"question\": lambda x: x[\"question\"]}\n",
|
||||
"} | router"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "8aeda930",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Thank you for the compliment! The sum of 2 + 2 is equal to 4.', additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"question\": \"whats 2 + 2\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29781123",
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5125a1e3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Anthropic Functions\n",
|
||||
"\n",
|
||||
"This notebook shows how to use an experimental wrapper around Anthropic that gives it the same API as OpenAI Functions."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "378be79b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_experimental.llms.anthropic_functions import AnthropicFunctions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "65499965",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initialize Model\n",
|
||||
"\n",
|
||||
"You can initialize this wrapper the same way you'd initialize ChatAnthropic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "e1d535f6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = AnthropicFunctions(model='claude-2')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fcc9eaf4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Passing in functions\n",
|
||||
"\n",
|
||||
"You can now pass in functions in a similar way"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "0779c320",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"functions=[\n",
|
||||
" {\n",
|
||||
" \"name\": \"get_current_weather\",\n",
|
||||
" \"description\": \"Get the current weather in a given location\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The city and state, e.g. San Francisco, CA\"\n",
|
||||
" },\n",
|
||||
" \"unit\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"enum\": [\"celsius\", \"fahrenheit\"]\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" ]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "ad75a933",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema import HumanMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "fc703085",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = model.predict_messages(\n",
|
||||
" [HumanMessage(content=\"whats the weater in boston?\")], \n",
|
||||
" functions=functions\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "04d7936a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' ', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\", \"unit\": \"fahrenheit\"}'}}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"response"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0072fdba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using for extraction\n",
|
||||
"\n",
|
||||
"You can now use this for extraction."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "7af5c567",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import create_extraction_chain\n",
|
||||
"schema = {\n",
|
||||
" \"properties\": {\n",
|
||||
" \"name\": {\"type\": \"string\"},\n",
|
||||
" \"height\": {\"type\": \"integer\"},\n",
|
||||
" \"hair_color\": {\"type\": \"string\"},\n",
|
||||
" },\n",
|
||||
" \"required\": [\"name\", \"height\"],\n",
|
||||
"}\n",
|
||||
"inp = \"\"\"\n",
|
||||
"Alex is 5 feet tall. Claudia is 1 feet taller Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.\n",
|
||||
" \"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "bd01082a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = create_extraction_chain(schema, model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "b5a23e9f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'Alex', 'height': '5', 'hair_color': 'blonde'},\n",
|
||||
" {'name': 'Claudia', 'height': '6', 'hair_color': 'brunette'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.run(inp)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "90ec959e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using for tagging\n",
|
||||
"\n",
|
||||
"You can now use this for tagging"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "03c1eb0d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import create_tagging_chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "581c0ece",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"schema = {\n",
|
||||
" \"properties\": {\n",
|
||||
" \"sentiment\": {\"type\": \"string\"},\n",
|
||||
" \"aggressiveness\": {\"type\": \"integer\"},\n",
|
||||
" \"language\": {\"type\": \"string\"},\n",
|
||||
" }\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "d9a8570e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = create_tagging_chain(schema, model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "cf37d679",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'sentiment': 'positive', 'aggressiveness': '0', 'language': 'english'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.run(\"this is really cool\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -215,23 +215,10 @@
|
||||
"Chroma has the ability to handle multiple `Collections` of documents, but the LangChain interface expects one, so we need to specify the collection name. The default collection name used by LangChain is \"langchain\".\n",
|
||||
"\n",
|
||||
"Here is how to clone, build, and run the Docker Image:\n",
|
||||
"```sh\n",
|
||||
"```\n",
|
||||
"git clone git@github.com:chroma-core/chroma.git\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Edit the `docker-compose.yml` file and add `ALLOW_RESET=TRUE` under `environment`\n",
|
||||
"```yaml\n",
|
||||
" ...\n",
|
||||
" command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml\n",
|
||||
" environment:\n",
|
||||
" - IS_PERSISTENT=TRUE\n",
|
||||
" - ALLOW_RESET=TRUE\n",
|
||||
" ports:\n",
|
||||
" - 8000:8000\n",
|
||||
" ...\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Then run `docker-compose up -d --build`"
|
||||
"docker-compose up -d --build\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
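As a hedged aside (not part of this diff), once the Dockerized server is up, pointing LangChain's `Chroma` wrapper at a named collection might look like the minimal sketch below; connecting to the remote server would additionally need a configured `chromadb` client, omitted here:

```python
# Minimal sketch; assumes OPENAI_API_KEY is set and a Chroma backend is reachable.
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

vectorstore = Chroma(
    collection_name="langchain",  # LangChain's default collection name
    embedding_function=OpenAIEmbeddings(),
)
```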
@@ -23,9 +23,9 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install faiss-gpu # For CUDA 7.5+ Supported GPU's.\n",
|
||||
"#!pip install faiss\n",
|
||||
"# OR\n",
|
||||
"!pip install faiss-cpu # For CPU Installation"
|
||||
"!pip install faiss-cpu"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
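For orientation only (not part of this diff), a minimal usage sketch after installing `faiss-cpu`:

```python
# Tiny illustrative corpus; assumes OPENAI_API_KEY is set.
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS

db = FAISS.from_texts(["hello world", "goodbye world"], OpenAIEmbeddings())
print(db.similarity_search("hi", k=1)[0].page_content)
```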
@@ -202,7 +202,7 @@
|
||||
"qdrant = Qdrant.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" url=url,\n",
|
||||
" url,\n",
|
||||
" prefer_grpc=True,\n",
|
||||
" collection_name=\"my_documents\",\n",
|
||||
")"
|
||||
@@ -236,7 +236,7 @@
|
||||
"qdrant = Qdrant.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" url=url,\n",
|
||||
" url,\n",
|
||||
" prefer_grpc=True,\n",
|
||||
" api_key=api_key,\n",
|
||||
" collection_name=\"my_documents\",\n",
|
||||
@@ -270,7 +270,7 @@
|
||||
"qdrant = Qdrant.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" url=url,\n",
|
||||
" url,\n",
|
||||
" prefer_grpc=True,\n",
|
||||
" collection_name=\"my_documents\",\n",
|
||||
" force_recreate=True,\n",
|
||||
|
||||
@@ -2,141 +2,131 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9787b308",
|
||||
"id": "20b588b4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Rockset\n",
|
||||
"\n",
|
||||
">[Rockset](https://rockset.com/) is a real-time search and analytics database built for the cloud. Rockset uses a [Converged Index™](https://rockset.com/blog/converged-indexing-the-secret-sauce-behind-rocksets-fast-queries/) with an efficient store for vector embeddings to serve low latency, high concurrency search queries at scale. Rockset has full support for metadata filtering and handles real-time ingestion for constantly updating, streaming data.\n",
|
||||
">[Rockset](https://rockset.com/product/) is a real-time analytics database service for serving low latency, high concurrency analytical queries at scale. It builds a Converged Index™ on structured and semi-structured data with an efficient store for vector embeddings. Its support for running SQL on schemaless data makes it a perfect choice for running vector search with metadata filters. \n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to use `Rockset` as a vector store in LangChain. Before getting started, make sure you have access to a `Rockset` account and an API key available. [Start your free trial today.](https://rockset.com/create/)\n"
|
||||
"This notebook demonstrates how to use `Rockset` as a vectorstore in langchain. To get started, make sure you have a `Rockset` account and an API key available."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b823d64a",
|
||||
"id": "e290ddc0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setting Up Your Environment[](https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/rockset#setting-up-environment)\n",
|
||||
"## Setting up environment\n",
|
||||
"\n",
|
||||
"1. Leverage the `Rockset` console to create a [collection](https://rockset.com/docs/collections/) with the Write API as your source. In this walkthrough, we create a collection named `langchain_demo`. \n",
|
||||
"1. Make sure you have Rockset account and go to the web console to get the API key. Details can be found on [the website](https://rockset.com/docs/rest-api/). For the purpose of this notebook, we will assume you're using Rockset from `Oregon(us-west-2)`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7d77bbbe",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. Now you will need to create a Rockset collection to write to, use the Rockset web console to do this. For the purpose of this exercise, we will create a collection called `langchain_demo`. Since Rockset supports schemaless ingest, you don't need to inform us of the shape of metadata for your texts. However, you do need to decide on two columns upfront:\n",
|
||||
"- Where to store the text. We will use the column `description` for this.\n",
|
||||
"- Where to store the vector-embedding for the text. We will use the column `description_embedding` for this.\n",
|
||||
"\n",
|
||||
"Also you will need to inform Rockset that `description_embedding` is a vector-embedding, so that we can optimize its format. You can do this using a **Rockset ingest transformation** while creating your collection:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "3daa76ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"SELECT\n",
|
||||
" _input.* EXCEPT(_meta),\n",
|
||||
" VECTOR_ENFORCE(_input.description_embedding, #length_of_vector_embedding, 'float') as description_embedding\n",
|
||||
"FROM\n",
|
||||
" _input\n",
|
||||
" \n",
|
||||
" Configure the following [ingest transformation](https://rockset.com/docs/ingest-transformation/) to mark your embeddings field and take advantage of performance and storage optimizations:"
|
||||
"// We used OpenAI `text-embedding-ada-002` for this examples, where #length_of_vector_embedding = 1536"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7951c9cd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"3. Now let's install the [rockset-python-client](https://github.com/rockset/rockset-python-client). This is used by langchain to talk to the Rockset database."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aac58387",
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "sql"
|
||||
}
|
||||
},
|
||||
"id": "2aac7ae6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SELECT _input.* EXCEPT(_meta), \n",
|
||||
"VECTOR_ENFORCE(_input.description_embedding, #length_of_vector_embedding, 'float') as description_embedding \n",
|
||||
"FROM _input"
|
||||
"!pip install rockset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "df380e1c",
|
||||
"id": "8600900d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. After creating your collection, use the console to retrieve an [API key](https://rockset.com/docs/iam/#users-api-keys-and-roles). For the purpose of this notebook, we assume you are using the `Oregon(us-west-2)` region.\n",
|
||||
"\n",
|
||||
"3. Install the [rockset-python-client](https://github.com/rockset/rockset-python-client) to enable LangChain to communicate directly with `Rockset`."
|
||||
"This is it! Now you're ready to start writing some python code to store vector embeddings in Rockset, and querying the database to find texts similar to your query! We support 3 distance functions: `COSINE_SIM`, `EUCLIDEAN_DIST` and `DOT_PRODUCT`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3bf2f818",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "00d16b83",
|
||||
"id": "a7b39626",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pip install rockset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e79550eb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## LangChain Tutorial\n",
|
||||
"\n",
|
||||
"Follow along in your own Python notebook to generate and store vector embeddings in Rockset.\n",
|
||||
"Start using Rockset to search for documents similar to your search queries.\n",
|
||||
"\n",
|
||||
"### 1. Define Key Variables"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "29505c1e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "InitializationException",
|
||||
"evalue": "The rockset client was initialized incorrectly: An api key must be provided as a parameter to the RocksetClient or the Configuration object.",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mInitializationException\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[5], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m ROCKSET_API_KEY \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39menviron\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mROCKSET_API_KEY\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39m# Verify ROCKSET_API_KEY environment variable\u001b[39;00m\n\u001b[1;32m 5\u001b[0m ROCKSET_API_SERVER \u001b[39m=\u001b[39m rockset\u001b[39m.\u001b[39mRegions\u001b[39m.\u001b[39musw2a1 \u001b[39m# Verify Rockset region\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m rockset_client \u001b[39m=\u001b[39m rockset\u001b[39m.\u001b[39;49mRocksetClient(ROCKSET_API_SERVER, ROCKSET_API_KEY)\n\u001b[1;32m 8\u001b[0m COLLECTION_NAME\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mlangchain_demo\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 9\u001b[0m TEXT_KEY\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mdescription\u001b[39m\u001b[39m'\u001b[39m\n",
|
||||
"File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/rockset/rockset_client.py:242\u001b[0m, in \u001b[0;36mRocksetClient.__init__\u001b[0;34m(self, host, api_key, max_workers, config)\u001b[0m\n\u001b[1;32m 239\u001b[0m config\u001b[39m.\u001b[39mhost \u001b[39m=\u001b[39m host\n\u001b[1;32m 241\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m config\u001b[39m.\u001b[39mapi_key:\n\u001b[0;32m--> 242\u001b[0m \u001b[39mraise\u001b[39;00m InitializationException(\n\u001b[1;32m 243\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mAn api key must be provided as a parameter to the RocksetClient or the Configuration object.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 244\u001b[0m )\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapi_client \u001b[39m=\u001b[39m ApiClient(config, max_workers\u001b[39m=\u001b[39mmax_workers)\n\u001b[1;32m 248\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mAliases \u001b[39m=\u001b[39m AliasesApiWrapper(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapi_client)\n",
|
||||
"\u001b[0;31mInitializationException\u001b[0m: The rockset client was initialized incorrectly: An api key must be provided as a parameter to the RocksetClient or the Configuration object."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import rockset\n",
|
||||
"\n",
|
||||
"ROCKSET_API_KEY = os.environ.get(\"ROCKSET_API_KEY\") # Verify ROCKSET_API_KEY environment variable\n",
|
||||
"ROCKSET_API_SERVER = rockset.Regions.usw2a1 # Verify Rockset region\n",
|
||||
"# Make sure env variable ROCKSET_API_KEY is set\n",
|
||||
"ROCKSET_API_KEY = os.environ.get(\"ROCKSET_API_KEY\")\n",
|
||||
"ROCKSET_API_SERVER = (\n",
|
||||
" rockset.Regions.usw2a1\n",
|
||||
") # Make sure this points to the correct Rockset region\n",
|
||||
"rockset_client = rockset.RocksetClient(ROCKSET_API_SERVER, ROCKSET_API_KEY)\n",
|
||||
"\n",
|
||||
"COLLECTION_NAME='langchain_demo'\n",
|
||||
"TEXT_KEY='description'\n",
|
||||
"EMBEDDING_KEY='description_embedding'"
|
||||
"COLLECTION_NAME = \"langchain_demo\"\n",
|
||||
"TEXT_KEY = \"description\"\n",
|
||||
"EMBEDDING_KEY = \"description_embedding\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "07625be2",
|
||||
"id": "474636a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2. Prepare Documents"
|
||||
"Now let's use this client to create a Rockset Langchain Vectorstore!\n",
|
||||
"\n",
|
||||
"### 1. Inserting texts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9740d8c4",
|
||||
"id": "0d73c5bb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31mRunning cells with '/opt/local/bin/python3.11' requires the ipykernel package.\n",
|
||||
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
|
||||
"\u001b[1;31mCommand: '/opt/local/bin/python3.11 -m pip install ipykernel -U --user --force-reinstall'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.vectorstores import Rockset\n",
|
||||
"\n",
|
||||
"loader = TextLoader('../../../state_of_the_union.txt')\n",
|
||||
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)"
|
||||
@@ -144,31 +134,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a068be18",
|
||||
"id": "1404cada",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 3. Insert Documents"
|
||||
"Now we have the documents we want to insert. Let's create a Rockset vectorstore and insert these docs into the Rockset collection. We will use `OpenAIEmbeddings` to create embeddings for the texts, but you're free to use whatever you want."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "85b6a6c5",
|
||||
"id": "63c98bac",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31mRunning cells with '/opt/local/bin/python3.11' requires the ipykernel package.\n",
|
||||
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
|
||||
"\u001b[1;31mCommand: '/opt/local/bin/python3.11 -m pip install ipykernel -U --user --force-reinstall'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = OpenAIEmbeddings() # Verify OPENAI_KEY environment variable\n",
|
||||
"# Make sure the environment variable OPENAI_API_KEY is set up\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"\n",
|
||||
"docsearch = Rockset(\n",
|
||||
" client=rockset_client,\n",
|
||||
@@ -178,38 +158,30 @@
|
||||
" embedding_key=EMBEDDING_KEY,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"ids=docsearch.add_texts(\n",
|
||||
"ids = docsearch.add_texts(\n",
|
||||
" texts=[d.page_content for d in docs],\n",
|
||||
" metadatas=[d.metadata for d in docs],\n",
|
||||
")"
|
||||
")\n",
|
||||
"\n",
|
||||
"## If you go to the Rockset console now, you should be able to see this docs along with the metadata `source`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "56eef48d",
|
||||
"id": "f1290844",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 4. Search for Similar Documents"
|
||||
"### 2. Searching similar texts\n",
|
||||
"\n",
|
||||
"Now let's try to search Rockset to find strings similar to our query string!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "0bbf3df0",
|
||||
"execution_count": null,
|
||||
"id": "96e73ac1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'docsearch' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m query \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mWhat did the president say about Ketanji Brown Jackson?\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m output \u001b[39m=\u001b[39m docsearch\u001b[39m.\u001b[39msimilarity_search_with_relevance_scores(query, \u001b[39m4\u001b[39m, Rockset\u001b[39m.\u001b[39mDistanceFunction\u001b[39m.\u001b[39mCOSINE_SIM)\n\u001b[1;32m 4\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39moutput length:\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mlen\u001b[39m(output))\n\u001b[1;32m 5\u001b[0m \u001b[39mfor\u001b[39;00m d, dist \u001b[39min\u001b[39;00m output:\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'docsearch' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"output = docsearch.similarity_search_with_relevance_scores(\n",
|
||||
@@ -217,7 +189,7 @@
|
||||
")\n",
|
||||
"print(\"output length:\", len(output))\n",
|
||||
"for d, dist in output:\n",
|
||||
" print(dist, d.metadata, d.page_content[:20] + '...')\n",
|
||||
" print(dist, d.metadata, d.page_content[:20] + \"...\")\n",
|
||||
"\n",
|
||||
"##\n",
|
||||
"# output length: 4\n",
|
||||
@@ -229,16 +201,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7037a22f",
|
||||
"id": "5e15d630",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 5. Search for Similar Documents with Filtering"
|
||||
"You can also use a where filter to prune your search space. You can add filters on text key, or any of the metadata fields. \n",
|
||||
"\n",
|
||||
"> **Note**: Since Rockset stores each metadata field as a separate column internally, these filters are much faster than other vector databases which store all metadata as a single JSON.\n",
|
||||
"\n",
|
||||
"For eg, to find all texts NOT containing the substring \"and\", you can use the following code:"
|
||||
]
|
||||
},
|
||||
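As a hedged sketch (the filter cell itself is elided from this hunk), the predicate would be passed through a `where_str`-style parameter; the parameter name is an assumption about the Rockset vectorstore API of this era:

```python
# Sketch only: where_str is assumed to accept a SQL predicate over the collection.
output = docsearch.similarity_search_with_relevance_scores(
    query,
    4,
    Rockset.DistanceFunction.COSINE_SIM,
    where_str="description NOT LIKE '%and%'",
)
```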
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b64a290f",
|
||||
"id": "c1c44d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -250,7 +226,7 @@
|
||||
")\n",
|
||||
"print(\"output length:\", len(output))\n",
|
||||
"for d, dist in output:\n",
|
||||
" print(dist, d.metadata, d.page_content[:20] + '...')\n",
|
||||
" print(dist, d.metadata, d.page_content[:20] + \"...\")\n",
|
||||
"\n",
|
||||
"##\n",
|
||||
"# output length: 4\n",
|
||||
@@ -263,13 +239,12 @@
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "13a52b38",
|
||||
"id": "0765b822",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 6. [Optional] Delete Inserted Documents\n",
|
||||
"### 3. [Optional] Drop all inserted documents\n",
|
||||
"\n",
|
||||
"You must have the unique ID associated with each document to delete them from your collection.\n",
|
||||
"Define IDs when inserting documents with `Rockset.add_texts()`. Rockset will otherwise generate a unique ID for each document. Regardless, `Rockset.add_texts()` returns the IDs of inserted documents.\n",
|
||||
"In order to delete texts from the Rockset collection, you need to know the unique ID associated with each document inside Rockset. These ids can either be supplied directly by the user while inserting the texts (in the `Rockset.add_texts()` function), else Rockset will generate a unique ID or each document. Either way, `Rockset.add_texts()` returns the ids for the inserted documents.\n",
|
||||
"\n",
|
||||
"To delete these docs, simply use the `Rockset.delete_texts()` function."
|
||||
]
|
||||
@@ -277,7 +252,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f755924",
|
||||
"id": "31738966",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -286,15 +261,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d468f431",
|
||||
"id": "03fa12a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Summary\n",
|
||||
"## Congratulations!\n",
|
||||
"\n",
|
||||
"In this tutorial, we successfully created a `Rockset` collection, `inserted` documents with OpenAI embeddings, and searched for similar documents with and without metadata filters.\n",
|
||||
"Voila! In this example you successfuly created a Rockset collection, inserted documents along with their OpenAI vector embeddings, and searched for similar docs both with and without any metadata filters.\n",
|
||||
"\n",
|
||||
"Keep an eye on https://rockset.com/ for future updates in this space."
|
||||
"Keep an eye on https://rockset.com/blog/introducing-vector-search-on-rockset/ for future updates in this space!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2763dddb-e87d-4d3b-b0bf-c246b0573d87",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -313,7 +296,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.6"
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
274
docs/extras/modules/agents/agent_types/anthropic_agent.ipynb
Normal file
@@ -0,0 +1,274 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "9926203f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n",
|
||||
"os.environ[\"LANGCHAIN_API_KEY\"] = \"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "45bc4149",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_instructions = \"\"\"You are a helpful assistant. Help the user answer any questions.\n",
|
||||
"\n",
|
||||
"You have access to the following tools:\n",
|
||||
"\n",
|
||||
"{tools}\n",
|
||||
"\n",
|
||||
"In order to use a tool, you can use <tool></tool> and <tool_input></tool_input> tags. \\\n",
|
||||
"You will then get back a response in the form <observation></observation>\n",
|
||||
"For example, if you have a tool called 'search' that could run a google search, in order to search for the weather in SF you would respond:\n",
|
||||
"\n",
|
||||
"<tool>search</tool><tool_input>weather in SF</tool_input>\n",
|
||||
"<observation>64 degrees</observation>\n",
|
||||
"\n",
|
||||
"When you are done, respond with a final answer between <final_answer></final_answer>. For example:\n",
|
||||
"\n",
|
||||
"<final_answer>The weather in SF is 64 degrees</final_answer>\n",
|
||||
"\n",
|
||||
"Begin!\n",
|
||||
"\n",
|
||||
"Question: {question}\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "4da4c0d2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatAnthropic\n",
|
||||
"from langchain.prompts import ChatPromptTemplate, AIMessagePromptTemplate\n",
|
||||
"from langchain.agents import tool"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "b81e9120",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ChatAnthropic(model=\"claude-2\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "5271f612",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt_template = ChatPromptTemplate.from_template(agent_instructions) + AIMessagePromptTemplate.from_template(\"{intermediate_steps}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "83780d81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = prompt_template | model.bind(stop=[\"</tool_input>\", \"</final_answer>\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "c091d0e1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tool\n",
|
||||
"def search(query: str) -> str:\n",
|
||||
" \"\"\"Search things about current events.\"\"\"\n",
|
||||
" return \"32 degrees\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "1e81b05d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tool_list = [search]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "5f0d986f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import Tool, AgentExecutor, BaseSingleActionAgent\n",
|
||||
"from typing import List, Tuple, Any, Union\n",
|
||||
"from langchain.schema import AgentAction, AgentFinish\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class AnthropicAgent(BaseSingleActionAgent):\n",
|
||||
" \n",
|
||||
" tools: List[Tool]\n",
|
||||
" chain: Any\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def input_keys(self):\n",
|
||||
" return [\"input\"]\n",
|
||||
"\n",
|
||||
" def plan(\n",
|
||||
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
|
||||
" ) -> Union[AgentAction, AgentFinish]:\n",
|
||||
" \"\"\"Given input, decided what to do.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" intermediate_steps: Steps the LLM has taken to date,\n",
|
||||
" along with observations\n",
|
||||
" **kwargs: User inputs.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" Action specifying what tool to use.\n",
|
||||
" \"\"\"\n",
|
||||
" log = \"\"\n",
|
||||
" for action, observation in intermediate_steps:\n",
|
||||
" log += f\"<tool>{action.tool}</tool><tool_input>{action.tool_input}</tool_input><observation>{observation}</observation>\"\n",
|
||||
" tools = \"\"\n",
|
||||
" for tool in self.tools:\n",
|
||||
" tools += f\"{tool.name}: {tool.description}\\n\"\n",
|
||||
" response = self.chain.invoke({\"intermediate_steps\": log, \"tools\": tools, \"question\": kwargs[\"input\"]})\n",
|
||||
" if \"</tool>\" in response.content:\n",
|
||||
" t, ti = response.content.split(\"</tool>\")\n",
|
||||
" _t = t.split(\"<tool>\")[1]\n",
|
||||
" _ti = ti.split(\"<tool_input>\")[1]\n",
|
||||
" return AgentAction(tool=_t, tool_input=_ti, log=response.content)\n",
|
||||
" elif \"<final_answer>\" in response.content:\n",
|
||||
" t, ti = response.content.split(\"<final_answer>\")\n",
|
||||
" return AgentFinish(return_values={\"output\": ti}, log=response.content)\n",
|
||||
" else:\n",
|
||||
" raise ValueError\n",
|
||||
"\n",
|
||||
" async def aplan(\n",
|
||||
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
|
||||
" ) -> Union[AgentAction, AgentFinish]:\n",
|
||||
" \"\"\"Given input, decided what to do.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" intermediate_steps: Steps the LLM has taken to date,\n",
|
||||
" along with observations\n",
|
||||
" **kwargs: User inputs.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" Action specifying what tool to use.\n",
|
||||
" \"\"\"\n",
|
||||
" raise ValueError"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "315361c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent = AnthropicAgent(tools=tool_list, chain=chain)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "bca6096f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_executor = AgentExecutor(agent=agent, tools=tool_list, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "71b872b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m <tool>search</tool>\n",
|
||||
"<tool_input>weather in new york\u001b[0m\u001b[36;1m\u001b[1;3m32 degrees\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"<final_answer>The weather in New York is 32 degrees\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The weather in New York is 32 degrees'"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_executor.run(\"whats the weather in New york?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cca87246",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
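For clarity (this is not a cell from the notebook above), the tag handling inside `plan()` reduces to the standalone sketch below. The closing `</tool_input>` tag never appears in the model output because it is bound as a stop sequence:

```python
# Standalone sketch of the parsing logic used by AnthropicAgent.plan above.
def parse_output(text: str):
    if "</tool>" in text:
        tool_part, input_part = text.split("</tool>")
        tool = tool_part.split("<tool>")[1]
        tool_input = input_part.split("<tool_input>")[1]
        return ("action", tool, tool_input)
    elif "<final_answer>" in text:
        _, answer = text.split("<final_answer>")
        return ("finish", answer)
    raise ValueError(f"Could not parse: {text!r}")


print(parse_output("<tool>search</tool><tool_input>weather in SF"))
# -> ('action', 'search', 'weather in SF')
```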
@@ -1,149 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3c284df8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# XML Agent\n",
|
||||
"\n",
|
||||
"Some language models (like Anthropic's Claude) are particularly good at reasoning/writing XML. This goes over how to use an agent that uses XML when prompting. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f9d2ead2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import XMLAgent, tool, AgentExecutor\n",
|
||||
"from langchain.chat_models import ChatAnthropic\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "ebadf04f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ChatAnthropic(model=\"claude-2\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "6ce9f9a5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tool\n",
|
||||
"def search(query: str) -> str:\n",
|
||||
" \"\"\"Search things about current events.\"\"\"\n",
|
||||
" return \"32 degrees\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "c589944e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tool_list = [search]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "2d8454be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = LLMChain(\n",
|
||||
" llm=model,\n",
|
||||
" prompt=XMLAgent.get_default_prompt(),\n",
|
||||
" output_parser=XMLAgent.get_default_output_parser()\n",
|
||||
")\n",
|
||||
"agent = XMLAgent(tools=tool_list, llm_chain=chain)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "bca6096f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_executor = AgentExecutor(agent=agent, tools=tool_list, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "71b872b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m <tool>search</tool>\n",
|
||||
"<tool_input>weather in New York\u001b[0m\u001b[36;1m\u001b[1;3m32 degrees\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"<final_answer>The weather in New York is 32 degrees\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The weather in New York is 32 degrees'"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_executor.run(\"whats the weather in New york?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cca87246",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -9,7 +9,7 @@
|
||||
"source": [
|
||||
"# ArangoDB QA chain\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/graph/graph_arangodb_qa.ipynb)\n",
|
||||
"[](https://colab.research.google.com/github/hwchase17/langchain/blob/master/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb)\n",
|
||||
"\n",
|
||||
"This notebook shows how to use LLMs to provide a natural language interface to an [ArangoDB](https://github.com/arangodb/arangodb#readme) database."
|
||||
]
|
||||
|
||||
@@ -10,11 +10,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.graphs import NeptuneGraph\n",
|
||||
"from langchain.graphs.neptune_graph import NeptuneGraph\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"host = \"<neptune-host>\"\n",
|
||||
@@ -26,23 +26,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The Austin airport has 98 outgoing routes.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.chains import NeptuneOpenCypherQAChain\n",
|
||||
"from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
"\n",
|
||||
@@ -53,22 +42,8 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"name": "python"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"source": [
|
||||
"# Tree of Thought (ToT) example\n",
|
||||
"\n",
|
||||
"The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. This is based on the paper [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)"
|
||||
"The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. This is based on the papaer [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
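As a heavily hedged sketch (the module paths, class names, and parameters below are assumptions about the `langchain_experimental` ToT API of this era, since the example cells are elided from this hunk), a minimal setup might look like:

```python
# Assumed API: ToTChain explores candidate thoughts; a checker grades each one
# so the chain can prune the tree (k = max rounds, c = children per node).
from langchain.llms import OpenAI
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity


class NonEmptyChecker(ToTChecker):
    def evaluate(self, problem_description, thoughts=()):
        # Toy rule: any non-empty thought counts as a promising intermediate step.
        return ThoughtValidity.VALID_INTERMEDIATE if thoughts else ThoughtValidity.INVALID


tot_chain = ToTChain(llm=OpenAI(temperature=1), checker=NonEmptyChecker(), k=10, c=3)
# tot_chain.run(problem_description="<your puzzle statement>")
```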
@@ -1,217 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "488d6ee8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Adding Memory to SQL Database Chain\n",
|
||||
"\n",
|
||||
"This notebook shows how to add memory to a SQLDatabaseChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "6ef6918e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.utilities import SQLDatabase\n",
|
||||
"from langchain_experimental.sql import SQLDatabaseChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "600aedb5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set up the SQLDatabase and LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b54c24c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n",
|
||||
"llm = OpenAI(temperature=0, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "96a1543f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set up the memory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "fc103f91",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"memory = ConversationBufferMemory()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "af31b91d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we need add to a place for memory in the prompt template"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "debcff82",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"PROMPT_SUFFIX = \"\"\"Only use the following tables:\n",
|
||||
"{table_info}\n",
|
||||
"\n",
|
||||
"Previous Conversation:\n",
|
||||
"{history}\n",
|
||||
"\n",
|
||||
"Question: {input}\"\"\"\n",
|
||||
"\n",
|
||||
"_DEFAULT_TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
|
||||
"\n",
|
||||
"Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n",
|
||||
"\n",
|
||||
"Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
|
||||
"\n",
|
||||
"Use the following format:\n",
|
||||
"\n",
|
||||
"Question: Question here\n",
|
||||
"SQLQuery: SQL Query to run\n",
|
||||
"SQLResult: Result of the SQLQuery\n",
|
||||
"Answer: Final answer here\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"PROMPT = PromptTemplate.from_template(\n",
|
||||
" _DEFAULT_TEMPLATE + PROMPT_SUFFIX,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "7f6115f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "b4753f69",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
||||
"name one employee\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT FirstName, LastName FROM Employee LIMIT 1\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[('Andrew', 'Adams')]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3mAndrew Adams\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Andrew Adams'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db_chain.run(\"name one employee\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "aa1100c8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
||||
"how many letters in their name?\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT LENGTH(FirstName) + LENGTH(LastName) AS 'NameLength' FROM Employee WHERE FirstName = 'Andrew' AND LastName = 'Adams'\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[(11,)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3mAndrew Adams has 11 letters in their name.\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Andrew Adams has 11 letters in their name.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db_chain.run(\"how many letters in their name?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "11525db8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
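The notebook above shows a follow-up question being resolved through chain memory. For orientation, a minimal sketch of the full setup it builds; the import paths and the Chinook database location are assumptions, not part of the diff:

```python
# Minimal sketch (not part of the diff) of the memory-enabled SQL chain above.
# Assumes the Chinook SQLite sample database at ./Chinook.db and an OpenAI key.
# SQLDatabaseChain lives in langchain_experimental.sql in this era of the
# codebase; older releases exposed it from langchain.chains instead.
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = OpenAI(temperature=0)
memory = ConversationBufferMemory()

# PROMPT is the template assembled in the notebook cell above; it must expose
# the memory key (e.g. {history}) so the chain can inject prior turns.
db_chain = SQLDatabaseChain.from_llm(
    llm, db, prompt=PROMPT, verbose=True, memory=memory
)
db_chain.run("name one employee")                # -> 'Andrew Adams'
db_chain.run("how many letters in their name?")  # follow-up answered via memory
```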
@@ -1,206 +0,0 @@
import json
from collections import defaultdict
from html.parser import HTMLParser
from typing import Any, DefaultDict, Dict, List, Optional

from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    SystemMessage,
)
from pydantic import root_validator

prompt = """In addition to responding, you can use tools. \
You have access to the following tools.

{tools}

In order to use a tool, you can use <tool></tool> to specify the name, \
and the <tool_input></tool_input> tags to specify the parameters. \
Each parameter should be passed in as <$param_name>$value</$param_name>, \
Where $param_name is the name of the specific parameter, and $value \
is the value for that parameter.

You will then get back a response in the form <observation></observation>
For example, if you have a tool called 'search' that accepts a single \
parameter 'query' that could run a google search, in order to search \
for the weather in SF you would respond:

<tool>search</tool><tool_input><query>weather in SF</query></tool_input>
<observation>64 degrees</observation>"""


class TagParser(HTMLParser):
    def __init__(self) -> None:
        """A heavy-handed solution, but it's fast for prototyping.

        Might be re-implemented later to restrict scope to the limited grammar, and
        for more efficiency.

        Uses an HTML parser to parse a limited grammar that allows
        for syntax of the form:

            INPUT -> JUNK? VALUE*
            JUNK -> JUNK_CHARACTER+
            JUNK_CHARACTER -> whitespace | ,
            VALUE -> <IDENTIFIER>DATA</IDENTIFIER> | OBJECT
            OBJECT -> <IDENTIFIER>VALUE+</IDENTIFIER>
            IDENTIFIER -> [a-Z][a-Z0-9_]*
            DATA -> .*

        Interprets the data to allow repetition of tags and recursion
        to support representation of complex types.

        ^ Just another approximately wrong grammar specification.
        """
        super().__init__()

        self.parse_data: DefaultDict[str, List[Any]] = defaultdict(list)
        self.stack: List[DefaultDict[str, List[str]]] = [self.parse_data]
        self.success = True
        self.depth = 0
        self.data: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: Any) -> None:
        """Hook when a new tag is encountered."""
        self.depth += 1
        self.stack.append(defaultdict(list))
        self.data = None

    def handle_endtag(self, tag: str) -> None:
        """Hook when a tag is closed."""
        self.depth -= 1
        top_of_stack = dict(self.stack.pop(-1))  # Pop the dictionary we don't need it

        # If a leaf node
        is_leaf = self.data is not None
        # Annoying to type here, code is tested, hopefully OK
        value = self.data if is_leaf else top_of_stack
        # Difficult to type this correctly with mypy (maybe impossible?)
        # Can be nested indefinitely, so requires self referencing type
        self.stack[-1][tag].append(value)  # type: ignore
        # Reset the data so that if we encounter a sequence of end tags, we
        # don't confuse an outer end tag for belonging to a leaf node.
        self.data = None

    def handle_data(self, data: str) -> None:
        """Hook when handling data."""
        stripped_data = data.strip()
        # The only data that's allowed is whitespace or a comma surrounded by whitespace
        if self.depth == 0 and stripped_data not in (",", ""):
            # If this is triggered the parse should be considered invalid.
            self.success = False
        if stripped_data:  # ignore whitespace-only strings
            self.data = stripped_data


def _destrip(tool_input: Any) -> Any:
    if isinstance(tool_input, dict):
        return {k: _destrip(v) for k, v in tool_input.items()}
    elif isinstance(tool_input, list):
        if isinstance(tool_input[0], str):
            if len(tool_input) == 1:
                return tool_input[0]
            else:
                raise ValueError
        elif isinstance(tool_input[0], dict):
            return [_destrip(v) for v in tool_input]
        else:
            raise ValueError
    else:
        raise ValueError


class AnthropicFunctions(BaseChatModel):
    model: ChatAnthropic

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        return {"model": ChatAnthropic(**values)}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        forced = False
        function_call = ""
        if "functions" in kwargs:
            content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2))
            system = SystemMessage(content=content)
            messages = [system] + messages
            del kwargs["functions"]
            if stop is None:
                stop = ["</tool_input>"]
            else:
                stop.append("</tool_input>")
            if "function_call" in kwargs:
                forced = True
                function_call = kwargs["function_call"]["name"]
                AIMessage(content=f"<tool>{function_call}</tool>")
                del kwargs["function_call"]
        else:
            if "function_call" in kwargs:
                raise ValueError(
                    "if `function_call` provided, `functions` must also be"
                )
        response = self.model.predict_messages(
            messages, stop=stop, callbacks=run_manager, **kwargs
        )
        completion = response.content
        if forced:
            tag_parser = TagParser()
            tag_parser.feed(completion.strip() + "</tool_input>")
            v1 = tag_parser.parse_data["tool_input"][0]
            kwargs = {
                "function_call": {
                    "name": function_call,
                    "arguments": json.dumps(_destrip(v1)),
                }
            }
            message = AIMessage(content="", additional_kwargs=kwargs)
            return ChatResult(generations=[ChatGeneration(message=message)])
        elif "<tool>" in completion:
            tag_parser = TagParser()
            tag_parser.feed(completion.strip() + "</tool_input>")
            msg = completion.split("<tool>")[0]
            v1 = tag_parser.parse_data["tool_input"][0]
            kwargs = {
                "function_call": {
                    "name": tag_parser.parse_data["tool"][0],
                    "arguments": json.dumps(_destrip(v1)),
                }
            }
            message = AIMessage(content=msg, additional_kwargs=kwargs)
            return ChatResult(generations=[ChatGeneration(message=message)])
        else:
            return ChatResult(generations=[ChatGeneration(message=response)])

    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        raise NotImplementedError

    @property
    def _llm_type(self) -> str:
        return "anthropic_functions"
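The file deleted above implemented OpenAI-functions-style tool calling on top of Claude by prompting for XML tags and parsing them back out with TagParser. A rough usage sketch, not part of the diff; the function schema below is illustrative:

```python
# Rough usage sketch (not part of the diff) for the wrapper defined above.
# AnthropicFunctions formats the schema into the XML-tag prompt, stops on
# </tool_input>, parses <tool>/<tool_input> with TagParser, and surfaces the
# result as an OpenAI-style `function_call` in additional_kwargs.
from langchain.schema.messages import HumanMessage

model = AnthropicFunctions(model="claude-2")  # kwargs forwarded to ChatAnthropic
functions = [  # illustrative schema, not from the source
    {
        "name": "search",
        "description": "Run a web search.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    }
]
response = model.predict_messages(
    [HumanMessage(content="What's the weather in SF?")], functions=functions
)
print(response.additional_kwargs.get("function_call"))
# e.g. {'name': 'search', 'arguments': '{"query": "weather in SF"}'}
```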
@@ -112,9 +112,6 @@ class SQLDatabaseChain(Chain):
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs)  # input: sql generation

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-experimental"
version = "0.0.8"
version = "0.0.7"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

@@ -34,7 +34,6 @@ from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain
from langchain.chains.graph_qa.kuzu import KuzuQAChain
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain
from langchain.chains.graph_qa.sparql import GraphSparqlQAChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain
@@ -102,7 +101,6 @@ __all__ = [
    "MultiRouteChain",
    "NatBotChain",
    "NebulaGraphQAChain",
    "NeptuneOpenCypherQAChain",
    "OpenAIModerationChain",
    "OpenAPIEndpointChain",
    "QAGenerationChain",

@@ -26,15 +26,23 @@ default_header_template = {
class AsyncHtmlLoader(BaseLoader):
    """Loads HTML asynchronously."""

    web_paths: List[str]

    requests_per_second: int = 2
    """Max number of concurrent requests to make."""

    requests_kwargs: Dict[str, Any] = {}
    """kwargs for requests"""

    raise_for_status: bool = False
    """Raise an exception if http status code denotes an error."""

    def __init__(
        self,
        web_path: Union[str, List[str]],
        header_template: Optional[dict] = None,
        verify_ssl: Optional[bool] = True,
        proxies: Optional[dict] = None,
        requests_per_second: int = 2,
        requests_kwargs: Dict[str, Any] = {},
        raise_for_status: bool = False,
    ):
        """Initialize with webpage path."""

@@ -66,10 +74,6 @@ class AsyncHtmlLoader(BaseLoader):
        if proxies:
            self.session.proxies.update(proxies)

        self.requests_per_second = requests_per_second
        self.requests_kwargs = requests_kwargs
        self.raise_for_status = raise_for_status

    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:

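The hunks above move `requests_per_second`, `requests_kwargs`, and `raise_for_status` from class attributes to constructor arguments. A short usage sketch under the new signature; the URLs and kwargs are placeholders:

```python
# Usage sketch (not part of the diff) for the new constructor signature.
from langchain.document_loaders import AsyncHtmlLoader

loader = AsyncHtmlLoader(
    ["https://example.com", "https://example.org"],  # placeholder URLs
    requests_per_second=2,            # max concurrent requests
    requests_kwargs={"timeout": 10},  # forwarded to the underlying HTTP GET
    raise_for_status=True,            # fail loudly on HTTP error statuses
)
docs = loader.load()
```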
@@ -49,18 +49,7 @@ class GitLoader(BaseLoader):
        if not os.path.exists(self.repo_path) and self.clone_url is None:
            raise ValueError(f"Path {self.repo_path} does not exist")
        elif self.clone_url:
            # If the repo_path already contains a git repository, verify that it's the
            # same repository as the one we're trying to clone.
            if os.path.isdir(os.path.join(self.repo_path, ".git")):
                repo = Repo(self.repo_path)
                # If the existing repository is not the same as the one we're trying to
                # clone, raise an error.
                if repo.remotes.origin.url != self.clone_url:
                    raise ValueError(
                        "A different repository is already cloned at this path."
                    )
            else:
                repo = Repo.clone_from(self.clone_url, self.repo_path)
            repo = Repo.clone_from(self.clone_url, self.repo_path)
            repo.git.checkout(self.branch)
        else:
            repo = Repo(self.repo_path)

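The hunk above concerns calling `load()` twice against the same path: the guarded version reuses an existing checkout when its `origin` remote matches `clone_url` and raises when it does not, instead of re-cloning unconditionally. A sketch of the behavior, with illustrative paths:

```python
# Usage sketch (not part of the diff); the path and URL are illustrative.
from langchain.document_loaders import GitLoader

loader = GitLoader(
    repo_path="./example_repo",
    clone_url="https://github.com/user/repo.git",
)
docs = loader.load()  # clones on first use
docs = loader.load()  # reuses the existing clone instead of failing

# Pointing a different clone_url at the same repo_path raises ValueError.
```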
@@ -469,7 +469,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):

    Examples:

        Instantiation from a list of message templates:
        Instantiation from a list of role strings and templates:

        .. code-block:: python

@@ -488,6 +488,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
                ("human", "Hello, how are you?"),
            ])

        Instantiation from a list of message templates:

        .. code-block:: python

            template = ChatPromptTemplate.from_messages([
                ("human", "Hello, how are you?"),
                ("ai", "I'm doing well, thanks!"),
                ("human", "That's good to hear."),
            ])

    Args:
        messages: sequence of message representations.
            A message can be represented using the following formats:

@@ -190,9 +190,12 @@ class FewShotChatMessagePromptTemplate(

        .. code-block:: python

            from langchain.schema import SystemMessage
            from langchain.prompts import (
                FewShotChatMessagePromptTemplate,
                ChatPromptTemplate
                HumanMessagePromptTemplate,
                SystemMessagePromptTemplate,
                AIMessagePromptTemplate
            )

            examples = [
@@ -200,23 +203,24 @@ class FewShotChatMessagePromptTemplate(
                {"input": "2+3", "output": "5"},
            ]

            example_prompt = ChatPromptTemplate.from_messages(
                [('human', '{input}'), ('ai', '{output}')]
            )

            few_shot_prompt = FewShotChatMessagePromptTemplate(
                examples=examples,
                # This is a prompt template used to format each individual example.
                example_prompt=example_prompt,
                example_prompt=(
                    HumanMessagePromptTemplate.from_template("{input}")
                    + AIMessagePromptTemplate.from_template("{output}")
                ),
            )

            final_prompt = ChatPromptTemplate.from_messages(
                [
                    ('system', 'You are a helpful AI Assistant'),
                    few_shot_prompt,
                    ('human', '{input}'),
                ]

            final_prompt = (
                SystemMessagePromptTemplate.from_template(
                    "You are a helpful AI Assistant"
                )
                + few_shot_prompt
                + HumanMessagePromptTemplate.from_template("{input}")
            )

            final_prompt.format(input="What is 4+4?")

        Prompt template with dynamically selected examples:

@@ -34,18 +34,18 @@ class SearchQueries(BaseModel):

DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""<<SYS>> \n You are an assistant tasked with improving Google search \
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that \
are similar to this question. The output should be a numbered list of questions \
and each should have a question mark at the end: \n\n {question} [/INST]""",
    template="""<<SYS>> \n You are an assistant tasked with improving Google search
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that
are similar to this question. The output should be a numbered list of questions
and each should have a question mark at the end: \n\n {question} [/INST]""",
)

DEFAULT_SEARCH_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
    template="""You are an assistant tasked with improving Google search
results. Generate THREE Google search queries that are similar to
this question. The output should be a numbered list of questions and each
should have a question mark at the end: {question}""",
)

@@ -108,10 +108,6 @@ class Runnable(Generic[Input, Output], ABC):
    ) -> List[Output]:
        configs = self._get_config_list(config, len(inputs))

        # If there's only one input, don't bother with the executor
        if len(inputs) == 1:
            return [self.invoke(inputs[0], configs[0])]

        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(executor.map(self.invoke, inputs, configs))

@@ -763,140 +759,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
            yield item


class RouterInput(TypedDict):
    key: str
    input: Any


class RouterRunnable(
    Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
    runnables: Mapping[str, Runnable[Input, Output]]

    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
        super().__init__(runnables=runnables)

    class Config:
        arbitrary_types_allowed = True

    @property
    def lc_serializable(self) -> bool:
        return True

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[RouterInput, Other]:
        return RunnableSequence(first=self, last=_coerce_to_runnable(other))

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[Other, Output]:
        return RunnableSequence(first=_coerce_to_runnable(other), last=self)

    def invoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        return runnable.invoke(actual_input, config)

    async def ainvoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        return await runnable.ainvoke(actual_input, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        keys = [input["key"] for input in inputs]
        actual_inputs = [input["input"] for input in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        runnables = [self.runnables[key] for key in keys]
        configs = self._get_config_list(config, len(inputs))
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(
                executor.map(
                    lambda runnable, input, config: runnable.invoke(input, config),
                    runnables,
                    actual_inputs,
                    configs,
                )
            )

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        keys = [input["key"] for input in inputs]
        actual_inputs = [input["input"] for input in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        runnables = [self.runnables[key] for key in keys]
        configs = self._get_config_list(config, len(inputs))
        return await _gather_with_concurrency(
            max_concurrency,
            *(
                runnable.ainvoke(input, config)
                for runnable, input, config in zip(runnables, actual_inputs, configs)
            ),
        )

    def stream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        yield from runnable.stream(actual_input, config)

    async def astream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")

        runnable = self.runnables[key]
        async for output in runnable.astream(actual_input, config):
            yield output


def _patch_config(
    config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:

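The block above is the RouterRunnable that the test file later in this diff exercises. Its contract in brief, as a sketch; the fake LLMs are stand-ins, not part of the source:

```python
# Sketch (not part of the diff): how the RouterRunnable above is driven.
# A RouterInput is {"key": ..., "input": ...}; the key selects which runnable
# receives the input. Import paths match the test file later in this diff.
from langchain.llms.fake import FakeListLLM
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RouterRunnable

chain1 = ChatPromptTemplate.from_template(
    "Math question: {question}"
) | FakeListLLM(responses=["4"])
chain2 = ChatPromptTemplate.from_template(
    "English question: {question}"
) | FakeListLLM(responses=["2"])

router = RouterRunnable({"math": chain1, "english": chain2})

router.invoke({"key": "math", "input": {"question": "2 + 2"}})  # -> "4"
router.batch(
    [
        {"key": "math", "input": {"question": "2 + 2"}},
        {"key": "english", "input": {"question": "2 + 2"}},
    ]
)  # -> ["4", "2"]
```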
@@ -280,9 +280,7 @@ class VectorStore(ABC):
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(
            self.similarity_search_with_relevance_scores, query, k=k, **kwargs
        )
        func = partial(self.similarity_search_with_relevance_scores, query, k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    async def asimilarity_search(
@@ -293,7 +291,7 @@ class VectorStore(ABC):
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search, query, k=k, **kwargs)
        func = partial(self.similarity_search, query, k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def similarity_search_by_vector(
@@ -318,7 +316,7 @@ class VectorStore(ABC):
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
        func = partial(self.similarity_search_by_vector, embedding, k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def max_marginal_relevance_search(
@@ -361,12 +359,7 @@ class VectorStore(ABC):
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(
            self.max_marginal_relevance_search,
            query,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
            self.max_marginal_relevance_search, query, k, fetch_k, lambda_mult, **kwargs
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

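The four hunks above swap between keyword binding (`k=k`) and positional binding in `functools.partial`. The distinction matters because positional binding silently feeds values to the wrong parameters if the wrapped signature is ever reordered, while keyword binding stays correct. A self-contained illustration, not taken from the source:

```python
# Why partial(f, query, k=k) and partial(f, query, k) can differ (illustrative).
from functools import partial

def search(query: str, k: int = 4, fetch_k: int = 20) -> tuple:
    return (query, k, fetch_k)

print(partial(search, "q", k=4)())  # ('q', 4, 20): k bound by name, order-proof
print(partial(search, "q", 4)())    # ('q', 4, 20): only while k is the 2nd param
```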
@@ -40,7 +40,7 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
    except ImportError:
        raise ImportError(
            "Could not import faiss python package. "
            "Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
            "Please install it with `pip install faiss` "
            "or `pip install faiss-cpu` (depending on Python version)."
        )
    return faiss

@@ -23,6 +23,7 @@ class Rockset(VectorStore):
    See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details

    Everything below assumes `commons` Rockset workspace.
    TODO: Add support for workspace args.

    Example:
        .. code-block:: python
@@ -49,7 +50,6 @@ class Rockset(VectorStore):
        collection_name: str,
        text_key: str,
        embedding_key: str,
        workspace: str = "commons",
    ):
        """Initialize with Rockset client.
        Args:
@@ -82,7 +82,6 @@ class Rockset(VectorStore):
        self._embeddings = embeddings
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._workspace = workspace

    @property
    def embeddings(self) -> Embeddings:
@@ -304,7 +303,7 @@ class Rockset(VectorStore):
        where_str = f"WHERE {where_str}\n" if where_str else ""
        return f"""\
SELECT * EXCEPT({self._embedding_key}), {distance_str}
FROM {self._workspace}.{self._collection_name}
FROM {self._collection_name}
{where_str}\
ORDER BY dist {distance_func.order_by()}
LIMIT {str(k)}
@@ -312,7 +311,7 @@ LIMIT {str(k)}

    def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
        add_doc_res = self._client.Documents.add_documents(
            collection=self._collection_name, data=batch, workspace=self._workspace
            collection=self._collection_name, data=batch
        )
        return [doc_status._id for doc_status in add_doc_res.data]

@@ -329,5 +328,4 @@ LIMIT {str(k)}
        self._client.Documents.delete_documents(
            collection=self._collection_name,
            data=[DeleteDocumentsRequestData(id=i) for i in ids],
            workspace=self._workspace,
        )

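The Rockset hunks above thread a `workspace` argument through the vector store so that queries and writes target `{workspace}.{collection}` rather than the bare collection name, with `commons` as the default so existing callers keep their behavior. A construction sketch; the host, key, and embedding model are illustrative:

```python
# Sketch (not part of the diff): constructing the store with the new argument.
import rockset
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Rockset

host = "api.usw2a1.rockset.com"  # illustrative
api_key = "..."                  # illustrative
client = rockset.RocksetClient(host, api_key)

vectorstore = Rockset(
    client,
    OpenAIEmbeddings(),
    collection_name="langchain_demo",
    text_key="description",
    embedding_key="description_embedding",
    workspace="commons",  # new argument, used for both queries and writes
)
```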
libs/langchain/poetry.lock (generated, 11 lines changed)
@@ -3269,14 +3269,14 @@ smmap = ">=3.0.1,<6"

[[package]]
name = "gitpython"
version = "3.1.32"
version = "3.1.31"
description = "GitPython is a Python library used to interact with Git repositories"
category = "main"
optional = true
python-versions = ">=3.7"
files = [
    {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"},
    {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"},
    {file = "GitPython-3.1.31-py3-none-any.whl", hash = "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"},
    {file = "GitPython-3.1.31.tar.gz", hash = "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573"},
]

[package.dependencies]
@@ -4652,7 +4652,6 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
    {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
    {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
]

[[package]]
@@ -13230,7 +13229,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xinference", "zep-python"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "geopandas", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xinference", "zep-python"]
javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
openai = ["openai", "tiktoken"]
@@ -13240,4 +13239,4 @@ text-helpers = ["chardet"]

[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "ef2b1d30e0fa872ce764c8a4cbc6e0a460bc9391a6465ee29d657e83b5459391"
content-hash = "5b1c718874d76c0e3b4023b2bceebe11a5e26e5e05d6797acf91b01b0438b2f7"

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.248"
version = "0.0.247"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -126,7 +126,6 @@ amadeus = {version = ">=8.1.0", optional = true}
geopandas = {version = "^0.13.1", optional = true}
xinference = {version = "^0.0.6", optional = true}
python-arango = {version = "^7.5.9", optional = true}
gitpython = {version = "^3.1.32", optional = true}

[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
@@ -360,7 +359,6 @@ extended_testing = [
    "geopandas",
    "jinja2",
    "xinference",
    "gitpython",
]

[tool.ruff]

@@ -34,12 +34,12 @@ def test_sql_query() -> None:

    client = rockset.RocksetClient(host, api_key)

    col_1 = "Rockset is a real-time analytics database"
    col_1 = "Rockset is a real-time analytics database which enables queries on massive, semi-structured data without operational burden. Rockset is serverless and fully managed. It offloads the work of managing configuration, cluster provisioning, denormalization, and shard / index management. Rockset is also SOC 2 Type II compliant and offers encryption at rest and in flight, securing and protecting any sensitive data. Most teams can ingest data into Rockset and start executing queries in less than 15 minutes."  # noqa: E501
    col_2 = 2
    col_3 = "e903e069-b0b5-4b80-95e2-86471b41f55f"
    id = 7320132

    """Run a simple SQL query"""
    """Run a simple SQL query query"""
    loader = RocksetLoader(
        client,
        rockset.models.QueryRequestSql(

@@ -33,7 +33,6 @@ logger = logging.getLogger(__name__)
#
# See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.

workspace = "langchain_tests"
collection_name = "langchain_demo"
text_key = "description"
embedding_key = "description_embedding"
@@ -72,9 +71,10 @@ class TestRockset:
            "Deleting all existing documents from the Rockset collection %s",
            collection_name,
        )
        query = f"select _id from {workspace}.{collection_name}"

        query_response = client.Queries.query(sql={"query": query})
        query_response = client.Queries.query(
            sql={"query": "select _id from {}".format(collection_name)}
        )
        ids = [
            str(r["_id"])
            for r in getattr(
@@ -85,13 +85,12 @@ class TestRockset:
        client.Documents.delete_documents(
            collection=collection_name,
            data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids],
            workspace=workspace,
        )

        embeddings = ConsistentFakeEmbeddings()
        embeddings.embed_documents(fake_texts)
        cls.rockset_vectorstore = Rockset(
            client, embeddings, collection_name, text_key, embedding_key, workspace
            client, embeddings, collection_name, text_key, embedding_key
        )

    def test_rockset_insert_and_search(self) -> None:
@@ -128,9 +127,9 @@ class TestRockset:
        )
        vector_str = ",".join(map(str, vector))
        expected = f"""\
SELECT * EXCEPT({embedding_key}), \
COSINE_SIM({embedding_key}, [{vector_str}]) as dist
FROM {workspace}.{collection_name}
SELECT * EXCEPT(description_embedding), \
COSINE_SIM(description_embedding, [{vector_str}]) as dist
FROM langchain_demo
ORDER BY dist DESC
LIMIT 4
"""
@@ -146,9 +145,9 @@ LIMIT 4
        )
        vector_str = ",".join(map(str, vector))
        expected = f"""\
SELECT * EXCEPT({embedding_key}), \
COSINE_SIM({embedding_key}, [{vector_str}]) as dist
FROM {workspace}.{collection_name}
SELECT * EXCEPT(description_embedding), \
COSINE_SIM(description_embedding, [{vector_str}]) as dist
FROM langchain_demo
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4

@@ -1,2 +0,0 @@
def test_import() -> None:
    from langchain.chains import NeptuneOpenCypherQAChain  # noqa: F401
@@ -1,65 +0,0 @@
import os

import py
import pytest

from langchain.document_loaders import GitLoader


def init_repo(tmpdir: py.path.local, dir_name: str) -> str:
    from git import Repo

    repo_dir = tmpdir.mkdir(dir_name)
    repo = Repo.init(repo_dir)
    git = repo.git
    git.checkout(b="main")

    git.config("user.name", "Test User")
    git.config("user.email", "test@example.com")

    sample_file = "file.txt"
    with open(os.path.join(repo_dir, sample_file), "w") as f:
        f.write("content")
    git.add([sample_file])
    git.commit(m="Initial commit")

    return repo_dir


@pytest.mark.requires("git")
def test_load_twice(tmpdir: py.path.local) -> None:
    """
    Test that loading documents twice from the same repository does not raise an error.
    """

    clone_url = init_repo(tmpdir, "remote_repo")

    repo_path = tmpdir.mkdir("local_repo").strpath
    loader = GitLoader(repo_path=repo_path, clone_url=clone_url)

    documents = loader.load()
    assert len(documents) == 1

    documents = loader.load()
    assert len(documents) == 1


@pytest.mark.requires("git")
def test_clone_different_repo(tmpdir: py.path.local) -> None:
    """
    Test that trying to clone a different repository into a directory already
    containing a clone raises a ValueError.
    """

    clone_url = init_repo(tmpdir, "remote_repo")

    repo_path = tmpdir.mkdir("local_repo").strpath
    loader = GitLoader(repo_path=repo_path, clone_url=clone_url)

    documents = loader.load()
    assert len(documents) == 1

    other_clone_url = init_repo(tmpdir, "other_remote_repo")
    other_loader = GitLoader(repo_path=repo_path, clone_url=other_clone_url)
    with pytest.raises(ValueError):
        other_loader.load()
@@ -1,2 +0,0 @@
def test_import() -> None:
    from langchain.graphs import NeptuneGraph  # noqa: F401
@@ -1,6 +1,6 @@
"""Tests for the time-weighted retriever class."""

from datetime import datetime, timedelta
from datetime import datetime
from typing import Any, Iterable, List, Optional, Tuple, Type

import pytest
@@ -139,11 +139,7 @@ def test_get_salient_docs(
) -> None:
    query = "Test query"
    docs_and_scores = time_weighted_retriever.get_salient_docs(query)
    want = [(doc, 0.5) for doc in _get_example_memories()]
    assert isinstance(docs_and_scores, dict)
    assert len(docs_and_scores) == len(want)
    for k, doc in docs_and_scores.items():
        assert doc in want


def test_get_relevant_documents(
@@ -151,17 +147,7 @@ def test_get_relevant_documents(
) -> None:
    query = "Test query"
    relevant_documents = time_weighted_retriever.get_relevant_documents(query)
    want = [(doc, 0.5) for doc in _get_example_memories()]
    assert isinstance(relevant_documents, list)
    assert len(relevant_documents) == len(want)
    now = datetime.now()
    for doc in relevant_documents:
        # assert that the last_accessed_at is close to now.
        assert now - timedelta(hours=1) < doc.metadata["last_accessed_at"] <= now

    # assert that the last_accessed_at in the memory stream is updated.
    for d in time_weighted_retriever.memory_stream:
        assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now


def test_add_documents(

File diff suppressed because one or more lines are too long
@@ -23,7 +23,6 @@ from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
    RouterRunnable,
    Runnable,
    RunnableConfig,
    RunnableLambda,
@@ -34,38 +33,16 @@ from langchain.schema.runnable import (


class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution.
    It replaces run ids with deterministic UUIDs for snapshotting."""
    """Fake tracer that records LangChain execution."""

    def __init__(self) -> None:
        """Initialize the tracer."""
        super().__init__()
        self.runs: List[Run] = []
        self.uuids_map: Dict[UUID, UUID] = {}
        self.uuids_generator = (
            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
        )

    def _replace_uuid(self, uuid: UUID) -> UUID:
        if uuid not in self.uuids_map:
            self.uuids_map[uuid] = next(self.uuids_generator)
        return self.uuids_map[uuid]

    def _copy_run(self, run: Run) -> Run:
        return run.copy(
            update={
                "id": self._replace_uuid(run.id),
                "parent_run_id": self.uuids_map[run.parent_run_id]
                if run.parent_run_id
                else None,
                "child_runs": [self._copy_run(child) for child in run.child_runs],
            }
        )

    def _persist_run(self, run: Run) -> None:
        """Persist a run."""

        self.runs.append(self._copy_run(run))
        self.runs.append(run)


class FakeRunnable(Runnable[str, int]):
@@ -101,6 +78,20 @@ class FakeRetriever(BaseRetriever):
        return [Document(page_content="foo"), Document(page_content="bar")]


@pytest.fixture()
def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher:
    """Note this mock only works with `import uuid; uuid.uuid4()`,
    it does not work with `from uuid import uuid4; uuid4()`."""

    # Disable tracing to avoid fixed UUIDs causing tracing errors.
    mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"})

    side_effect = (
        UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
    )
    return mocker.patch("uuid.uuid4", side_effect=side_effect)


@pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None:
    fake = FakeRunnable()
@@ -215,13 +206,13 @@ async def test_prompt() -> None:
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model(
    mocker: MockerFixture, snapshot: SnapshotAssertion
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo"])
    chat = FakeListChatModel(responses=["foo", "bar"])

    chain = prompt | chat

@@ -260,7 +251,7 @@ async def test_prompt_with_chat_model(
        ],
        dict(callbacks=[tracer]),
    ) == [
        AIMessage(content="foo"),
        AIMessage(content="bar"),
        AIMessage(content="foo"),
    ]
    assert prompt_spy.call_args.args[1] == [
@@ -281,16 +272,7 @@ async def test_prompt_with_chat_model(
            ]
        ),
    ]
    assert (
        len(
            [
                r
                for r in tracer.runs
                if r.parent_run_id is None and len(r.child_runs) == 2
            ]
        )
        == 2
    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

@@ -300,7 +282,7 @@ async def test_prompt_with_chat_model(
    tracer = FakeTracer()
    assert [
        *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
    ] == [AIMessage(content="foo")]
    ] == [AIMessage(content="bar")]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
@@ -313,7 +295,7 @@ async def test_prompt_with_chat_model(
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -404,7 +386,7 @@ async def test_prompt_with_llm(

@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
    mocker: MockerFixture, snapshot: SnapshotAssertion
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -442,7 +424,7 @@ def test_prompt_with_chat_model_and_parser(

@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

@@ -505,16 +487,13 @@ What is your name?"""
        ]
    )
    assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
    assert len(parent_run.child_runs) == 4
    map_run = parent_run.child_runs[0]
    assert map_run.name == "RunnableMap"
    assert len(map_run.child_runs) == 3
    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
def test_seq_prompt_dict(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

    prompt = (
@@ -565,64 +544,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
            HumanMessage(content="What is your name?"),
        ]
    )
    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
    assert len(parent_run.child_runs) == 3
    map_run = parent_run.child_runs[2]
    assert map_run.name == "RunnableMap"
    assert len(map_run.child_runs) == 2


@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_router_runnable(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    chain1 = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    ) | FakeListLLM(responses=["4"])
    chain2 = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    ) | FakeListLLM(responses=["2"])
    router = RouterRunnable({"math": chain1, "english": chain2})
    chain: Runnable = {
        "key": lambda x: x["key"],
        "input": {"question": lambda x: x["question"]},
    } | router
    assert dumps(chain, pretty=True) == snapshot

    result = chain.invoke({"key": "math", "question": "2 + 2"})
    assert result == "4"

    result2 = chain.batch(
        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
    )
    assert result2 == ["4", "2"]

    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
    assert result == "4"

    result2 = await chain.abatch(
        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
    )
    assert result2 == ["4", "2"]

    # Test invoke
    router_spy = mocker.spy(router.__class__, "invoke")
    tracer = FakeTracer()
    assert (
        chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
        == "4"
    )
    assert router_spy.call_args.args[1] == {
        "key": "math",
        "input": {"question": "2 + 2"},
    }
    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
def test_seq_prompt_map(
    mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
    passthrough = mocker.Mock(side_effect=lambda x: x)

    prompt = (
@@ -680,12 +608,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
            HumanMessage(content="What is your name?"),
        ]
    )
    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
    assert len(parent_run.child_runs) == 3
    map_run = parent_run.child_runs[2]
    assert map_run.name == "RunnableMap"
    assert len(map_run.child_runs) == 3
    assert tracer.runs == snapshot


def test_bind_bind() -> None:

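This test file trades between two ways of making run ids deterministic for snapshot assertions: monkeypatching `uuid.uuid4` via the `fixed_uuids` fixture, or remapping ids inside the tracer at persist time, as FakeTracer's `_replace_uuid`/`_copy_run` pair does above. The remapping trick in isolation, as an illustrative sketch rather than the library's code:

```python
# Deterministic UUID remapping for snapshot tests (illustrative sketch).
from uuid import UUID, uuid4

uuids_map: dict = {}
uuids_generator = (
    UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)

def replace_uuid(uuid: UUID) -> UUID:
    # First time an id is seen, assign the next deterministic UUID;
    # reuse the mapping thereafter so parent/child references stay consistent.
    if uuid not in uuids_map:
        uuids_map[uuid] = next(uuids_generator)
    return uuids_map[uuid]

a, b = uuid4(), uuid4()
assert replace_uuid(a) == replace_uuid(a)  # stable per original id
assert replace_uuid(a) != replace_uuid(b)  # distinct ids stay distinct
```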