mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-20 13:28:53 +00:00
Compare commits
113 Commits
vwp/tools_
...
vwp/align_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a77540fbe | ||
|
|
5763d26b9e | ||
|
|
54076f21b2 | ||
|
|
f3d727147a | ||
|
|
c825bd45d8 | ||
|
|
3b10dabe4d | ||
|
|
21f0719c9e | ||
|
|
0094879504 | ||
|
|
396a4b0458 | ||
|
|
b76f8cd252 | ||
|
|
2b7d51706e | ||
|
|
0e7e1e66f9 | ||
|
|
af302d99f0 | ||
|
|
1c2e2e93c8 | ||
|
|
0e06e6e34a | ||
|
|
37b819cfa5 | ||
|
|
ec00fc71a8 | ||
|
|
ceec14f1bf | ||
|
|
4aa03b3e01 | ||
|
|
7e6097964e | ||
|
|
5104f9b08c | ||
|
|
871c295b4c | ||
|
|
07627b57ec | ||
|
|
6f514361be | ||
|
|
b7f4a410a3 | ||
|
|
eb12242495 | ||
|
|
aa1c3df5cf | ||
|
|
f7af565510 | ||
|
|
3ec77607dc | ||
|
|
597e87abac | ||
|
|
dfdb8279a6 | ||
|
|
1999294349 | ||
|
|
6732ef9d35 | ||
|
|
2ba18a0096 | ||
|
|
48997b35c9 | ||
|
|
9d7cfbcfcc | ||
|
|
8fb767b8c6 | ||
|
|
69db22be32 | ||
|
|
5f0248f0fb | ||
|
|
580f1b2a48 | ||
|
|
be794e0360 | ||
|
|
659e94fc9c | ||
|
|
74a95629a3 | ||
|
|
1781d611f8 | ||
|
|
ca0dfd38f8 | ||
|
|
4630916e8c | ||
|
|
57a6982007 | ||
|
|
cdbc4cda37 | ||
|
|
4f501e59ec | ||
|
|
c850a4d406 | ||
|
|
aa9cf24a54 | ||
|
|
ffac033150 | ||
|
|
8fc1c43e5d | ||
|
|
1deacb4f0a | ||
|
|
621ab11734 | ||
|
|
4bb95ad529 | ||
|
|
8f5996a31c | ||
|
|
68c19e1452 | ||
|
|
df0e1f85da | ||
|
|
99f74ff7d9 | ||
|
|
fe5db65628 | ||
|
|
59a4a8b34b | ||
|
|
73aedeed07 | ||
|
|
e7d27d52f6 | ||
|
|
1c73dc6408 | ||
|
|
b9d0e88584 | ||
|
|
7482cc218c | ||
|
|
cc247960a4 | ||
|
|
2e2be677c9 | ||
|
|
dca5772ed9 | ||
|
|
5adfda8507 | ||
|
|
704e0b98d8 | ||
|
|
cc6902f817 | ||
|
|
2f1ab146d5 | ||
|
|
9bcb2af86a | ||
|
|
cdc9c6a2fd | ||
|
|
d0fa3cf798 | ||
|
|
d80017f51f | ||
|
|
bf0bbc8f2c | ||
|
|
27f1463f4a | ||
|
|
f7b05e7348 | ||
|
|
bf795bffdb | ||
|
|
906488f87e | ||
|
|
7a01742895 | ||
|
|
cef046ae18 | ||
|
|
5e53336c7d | ||
|
|
95ae3c5f4b | ||
|
|
fa9c5ac78d | ||
|
|
d5ef266842 | ||
|
|
3fdfa5d576 | ||
|
|
e41a70eb59 | ||
|
|
71db9c97c6 | ||
|
|
042415eee4 | ||
|
|
37cc3d2e63 | ||
|
|
828c96072c | ||
|
|
edbd3c7964 | ||
|
|
f553d28a11 | ||
|
|
e8e8ca163b | ||
|
|
c9d5525485 | ||
|
|
8f4f90cdae | ||
|
|
612f928323 | ||
|
|
7c211d2438 | ||
|
|
6a0abccf4d | ||
|
|
beb0f6fd60 | ||
|
|
219b618a5b | ||
|
|
fcd174cf43 | ||
|
|
74f46262d0 | ||
|
|
058273174a | ||
|
|
1a4c4a24f2 | ||
|
|
f6c98a7c1e | ||
|
|
e0cb4c3005 | ||
|
|
97cabb40ae | ||
|
|
37b68dc8f2 |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -144,8 +144,4 @@ wandb/
|
||||
/.ruff_cache/
|
||||
|
||||
*.pkl
|
||||
*.bin
|
||||
|
||||
# integration test artifacts
|
||||
data_map*
|
||||
\[('_type', 'fake'), ('stop', None)]
|
||||
*.bin
|
||||
@@ -4,8 +4,6 @@
|
||||
|
||||
[](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml) [](https://pepy.tech/project/langchain) [](https://opensource.org/licenses/MIT) [](https://twitter.com/langchainai) [](https://discord.gg/6adMQxSpJS)
|
||||
|
||||
Looking for the JS/TS version? Check out [LangChain.js](https://github.com/hwchase17/langchainjs).
|
||||
|
||||
**Production Support:** As you move your LangChains into production, we'd love to offer more comprehensive support.
|
||||
Please fill out [this form](https://forms.gle/57d8AmXBYp8PP8tZA) and we'll set up a dedicated support Slack channel.
|
||||
|
||||
|
||||
BIN
docs/_static/MetalDash.png
vendored
BIN
docs/_static/MetalDash.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 3.5 MiB |
@@ -1,10 +1,14 @@
|
||||
# Deployments
|
||||
|
||||
So, you've created a really cool chain - now what? How do you deploy it and make it easily shareable with the world?
|
||||
So you've made a really cool chain - now what? How do you deploy it and make it easily sharable with the world?
|
||||
|
||||
This section covers several options for that. Note that these options are meant for quick deployment of prototypes and demos, not for production systems. If you need help with the deployment of a production system, please contact us directly.
|
||||
This section covers several options for that.
|
||||
Note that these are meant as quick deployment options for prototypes and demos, and not for production systems.
|
||||
If you are looking for help with deployment of a production system, please contact us directly.
|
||||
|
||||
What follows is a list of template GitHub repositories designed to be easily forked and modified to use your chain. This list is far from exhaustive, and we are EXTREMELY open to contributions here.
|
||||
What follows is a list of template GitHub repositories aimed that are intended to be
|
||||
very easy to fork and modify to use your chain.
|
||||
This is far from an exhaustive list of options, and we are EXTREMELY open to contributions here.
|
||||
|
||||
## [Streamlit](https://github.com/hwchase17/langchain-streamlit-template)
|
||||
|
||||
@@ -43,11 +47,12 @@ A minimal example on how to deploy LangChain to Google Cloud Run.
|
||||
|
||||
## [SteamShip](https://github.com/steamship-core/steamship-langchain/)
|
||||
|
||||
This repository contains LangChain adapters for Steamship, enabling LangChain developers to rapidly deploy their apps on Steamship. This includes: production-ready endpoints, horizontal scaling across dependencies, persistent storage of app state, multi-tenancy support, etc.
|
||||
This repository contains LangChain adapters for Steamship, enabling LangChain developers to rapidly deploy their apps on Steamship.
|
||||
This includes: production ready endpoints, horizontal scaling across dependencies, persistant storage of app state, multi-tenancy support, etc.
|
||||
|
||||
## [Langchain-serve](https://github.com/jina-ai/langchain-serve)
|
||||
|
||||
This repository allows users to serve local chains and agents as RESTful, gRPC, or WebSocket APIs, thanks to [Jina](https://docs.jina.ai/). Deploy your chains & agents with ease and enjoy independent scaling, serverless and autoscaling APIs, as well as a Streamlit playground on Jina AI Cloud.
|
||||
This repository allows users to serve local chains and agents as RESTful, gRPC, or Websocket APIs thanks to [Jina](https://docs.jina.ai/). Deploy your chains & agents with ease and enjoy independent scaling, serverless and autoscaling APIs, as well as a Streamlit playground on Jina AI Cloud.
|
||||
|
||||
## [BentoML](https://github.com/ssheng/BentoChain)
|
||||
|
||||
@@ -55,4 +60,4 @@ This repository provides an example of how to deploy a LangChain application wit
|
||||
|
||||
## [Databutton](https://databutton.com/home?new-data-app=true)
|
||||
|
||||
These templates serve as examples of how to build, deploy, and share LangChain applications using Databutton. You can create user interfaces with Streamlit, automate tasks by scheduling Python code, and store files and data in the built-in store. Examples include a Chatbot interface with conversational memory, a Personal search engine, and a starter template for LangChain apps. Deploying and sharing is just one click away.
|
||||
These templates serve as examples of how to build, deploy, and share LangChain applications using Databutton. You can create user interfaces with Streamlit, automate tasks by scheduling Python code, and store files and data in the built-in store. Examples include Chatbot interface with conversational memory, Personal search engine, and a starter template for LangChain apps. Deploying and sharing is one click.
|
||||
|
||||
@@ -61,6 +61,7 @@
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks import AimCallbackHandler, StdOutCallbackHandler"
|
||||
]
|
||||
},
|
||||
@@ -108,8 +109,8 @@
|
||||
" experiment_name=\"scenario 1: OpenAI LLM\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"callbacks = [StdOutCallbackHandler(), aim_callback]\n",
|
||||
"llm = OpenAI(temperature=0, callbacks=callbacks)"
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), aim_callback])\n",
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -176,7 +177,7 @@
|
||||
"Title: {title}\n",
|
||||
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
|
||||
"\n",
|
||||
"test_prompts = [\n",
|
||||
" {\"title\": \"documentary about good video games that push the boundary of game design\"},\n",
|
||||
@@ -248,12 +249,13 @@
|
||||
],
|
||||
"source": [
|
||||
"# scenario 3 - Agent with Tools\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
|
||||
" callbacks=callbacks,\n",
|
||||
" callback_manager=manager,\n",
|
||||
" verbose=True,\n",
|
||||
")\n",
|
||||
"agent.run(\n",
|
||||
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
|
||||
|
||||
@@ -79,6 +79,7 @@
|
||||
"source": [
|
||||
"from datetime import datetime\n",
|
||||
"from langchain.callbacks import ClearMLCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"# Setup and use the ClearML Callback\n",
|
||||
@@ -92,9 +93,9 @@
|
||||
" complexity_metrics=True,\n",
|
||||
" stream_logs=True\n",
|
||||
")\n",
|
||||
"callbacks = [StdOutCallbackHandler(), clearml_callback]\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), clearml_callback])\n",
|
||||
"# Get the OpenAI model ready to go\n",
|
||||
"llm = OpenAI(temperature=0, callbacks=callbacks)"
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -522,12 +523,13 @@
|
||||
"from langchain.agents import AgentType\n",
|
||||
"\n",
|
||||
"# SCENARIO 2 - Agent with Tools\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
|
||||
" callbacks=callbacks,\n",
|
||||
" callback_manager=manager,\n",
|
||||
" verbose=True,\n",
|
||||
")\n",
|
||||
"agent.run(\n",
|
||||
" \"Who is the wife of the person who sang summer of 69?\"\n",
|
||||
|
||||
@@ -121,6 +121,7 @@
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"comet_callback = CometCallbackHandler(\n",
|
||||
@@ -130,8 +131,8 @@
|
||||
" tags=[\"llm\"],\n",
|
||||
" visualizations=[\"dep\"],\n",
|
||||
")\n",
|
||||
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
|
||||
"llm = OpenAI(temperature=0.9, callbacks=callbacks, verbose=True)\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
|
||||
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
|
||||
"\n",
|
||||
"llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\", \"Tell me a fact\"] * 3)\n",
|
||||
"print(\"LLM result\", llm_result)\n",
|
||||
@@ -152,6 +153,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
@@ -162,14 +164,15 @@
|
||||
" stream_logs=True,\n",
|
||||
" tags=[\"synopsis-chain\"],\n",
|
||||
")\n",
|
||||
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
|
||||
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
|
||||
"\n",
|
||||
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
|
||||
"Title: {title}\n",
|
||||
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
|
||||
"\n",
|
||||
"test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n",
|
||||
"print(synopsis_chain.apply(test_prompts))\n",
|
||||
@@ -191,6 +194,7 @@
|
||||
"source": [
|
||||
"from langchain.agents import initialize_agent, load_tools\n",
|
||||
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"comet_callback = CometCallbackHandler(\n",
|
||||
@@ -199,15 +203,15 @@
|
||||
" stream_logs=True,\n",
|
||||
" tags=[\"agent\"],\n",
|
||||
")\n",
|
||||
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
|
||||
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
|
||||
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
|
||||
"\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=\"zero-shot-react-description\",\n",
|
||||
" callbacks=callbacks,\n",
|
||||
" callback_manager=manager,\n",
|
||||
" verbose=True,\n",
|
||||
")\n",
|
||||
"agent.run(\n",
|
||||
@@ -251,6 +255,7 @@
|
||||
"from rouge_score import rouge_scorer\n",
|
||||
"\n",
|
||||
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
@@ -293,10 +298,10 @@
|
||||
" tags=[\"custom_metrics\"],\n",
|
||||
" custom_metrics=rouge_score.compute_metric,\n",
|
||||
")\n",
|
||||
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
|
||||
"llm = OpenAI(temperature=0.9)\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
|
||||
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
|
||||
"\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
|
||||
"\n",
|
||||
"test_prompts = [\n",
|
||||
" {\n",
|
||||
@@ -318,7 +323,7 @@
|
||||
" \"\"\"\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"print(synopsis_chain.apply(test_prompts, callbacks=callbacks))\n",
|
||||
"print(synopsis_chain.apply(test_prompts))\n",
|
||||
"comet_callback.flush_tracker(synopsis_chain, finish=True)"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
This page covers how to use the `GPT4All` wrapper within LangChain. The tutorial is divided into two parts: installation and setup, followed by usage with an example.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install the Python package with `pip install pyllamacpp`
|
||||
- Download a [GPT4All model](https://github.com/nomic-ai/pyllamacpp#supported-model) and place it in your desired directory
|
||||
|
||||
@@ -29,16 +28,16 @@ To stream the model's predictions, add in a CallbackManager.
|
||||
|
||||
```python
|
||||
from langchain.llms import GPT4All
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
# There are many CallbackHandlers supported, such as
|
||||
# from langchain.callbacks.streamlit import StreamlitCallbackHandler
|
||||
|
||||
callbacks = [StreamingStdOutCallbackHandler()]
|
||||
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
|
||||
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
||||
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True)
|
||||
|
||||
# Generate text. Tokens are streamed through the callback manager.
|
||||
model("Once upon a time, ", callbacks=callbacks)
|
||||
model("Once upon a time, ")
|
||||
```
|
||||
|
||||
## Model File
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
# LanceDB
|
||||
|
||||
This page covers how to use [LanceDB](https://github.com/lancedb/lancedb) within LangChain.
|
||||
It is broken into two parts: installation and setup, and then references to specific LanceDB wrappers.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install the Python SDK with `pip install lancedb`
|
||||
|
||||
## Wrappers
|
||||
|
||||
### VectorStore
|
||||
|
||||
There exists a wrapper around LanceDB databases, allowing you to use it as a vectorstore,
|
||||
whether for semantic search or example selection.
|
||||
|
||||
To import this vectorstore:
|
||||
|
||||
```python
|
||||
from langchain.vectorstores import LanceDB
|
||||
```
|
||||
|
||||
For a more detailed walkthrough of the LanceDB wrapper, see [this notebook](../modules/indexes/vectorstores/examples/lancedb.ipynb)
|
||||
@@ -1,26 +0,0 @@
|
||||
# Metal
|
||||
|
||||
This page covers how to use [Metal](https://getmetal.io) within LangChain.
|
||||
|
||||
## What is Metal?
|
||||
|
||||
Metal is a managed retrieval & memory platform built for production. Easily index your data into `Metal` and run semantic search and retrieval on it.
|
||||
|
||||

|
||||
|
||||
## Quick start
|
||||
|
||||
Get started by [creating a Metal account](https://app.getmetal.io/signup).
|
||||
|
||||
Then, you can easily take advantage of the `MetalRetriever` class to start retrieving your data for semantic search, prompting context, etc. This class takes a `Metal` instance and a dictionary of parameters to pass to the Metal API.
|
||||
|
||||
```python
|
||||
from langchain.retrievers import MetalRetriever
|
||||
from metal_sdk.metal import Metal
|
||||
|
||||
|
||||
metal = Metal("API_KEY", "CLIENT_ID", "INDEX_ID");
|
||||
retriever = MetalRetriever(metal, params={"limit": 2})
|
||||
|
||||
docs = retriever.get_relevant_documents("search term")
|
||||
```
|
||||
@@ -1,19 +0,0 @@
|
||||
# PipelineAI
|
||||
|
||||
This page covers how to use the PipelineAI ecosystem within LangChain.
|
||||
It is broken into two parts: installation and setup, and then references to specific PipelineAI wrappers.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install with `pip install pipeline-ai`
|
||||
- Get a Pipeline Cloud api key and set it as an environment variable (`PIPELINE_API_KEY`)
|
||||
|
||||
## Wrappers
|
||||
|
||||
### LLM
|
||||
|
||||
There exists a PipelineAI LLM wrapper, which you can access with
|
||||
|
||||
```python
|
||||
from langchain.llms import PipelineAI
|
||||
```
|
||||
@@ -50,6 +50,7 @@
|
||||
"source": [
|
||||
"from datetime import datetime\n",
|
||||
"from langchain.callbacks import WandbCallbackHandler, StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.llms import OpenAI"
|
||||
]
|
||||
},
|
||||
@@ -195,8 +196,8 @@
|
||||
" name=\"llm\",\n",
|
||||
" tags=[\"test\"],\n",
|
||||
")\n",
|
||||
"callbacks = [StdOutCallbackHandler(), wandb_callback]\n",
|
||||
"llm = OpenAI(temperature=0, callbacks=callbacks)"
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), wandb_callback])\n",
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -483,7 +484,7 @@
|
||||
"Title: {title}\n",
|
||||
"Playwright: This is a synopsis for the above play:\"\"\"\n",
|
||||
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
|
||||
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
|
||||
"\n",
|
||||
"test_prompts = [\n",
|
||||
" {\n",
|
||||
@@ -576,15 +577,16 @@
|
||||
],
|
||||
"source": [
|
||||
"# SCENARIO 3 - Agent with Tools\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools,\n",
|
||||
" llm,\n",
|
||||
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
|
||||
" callback_manager=manager,\n",
|
||||
" verbose=True,\n",
|
||||
")\n",
|
||||
"agent.run(\n",
|
||||
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\",\n",
|
||||
" callbacks=callbacks,\n",
|
||||
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
|
||||
")\n",
|
||||
"wandb_callback.flush_tracker(agent, reset=False, finish=True)"
|
||||
]
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
"from langchain.agents import AgentType\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.tracers import LangChainTracer\n",
|
||||
"from aiohttp import ClientSession\n",
|
||||
"\n",
|
||||
@@ -306,14 +307,14 @@
|
||||
" # To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
|
||||
" # but you must manually close the client session at the end of your program/event loop\n",
|
||||
" aiosession = ClientSession()\n",
|
||||
" callbacks = [StdOutCallbackHandler()]\n",
|
||||
" for _ in questions:\n",
|
||||
" llm = OpenAI(temperature=0)\n",
|
||||
" async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n",
|
||||
" manager = CallbackManager([StdOutCallbackHandler()])\n",
|
||||
" llm = OpenAI(temperature=0, callback_manager=manager)\n",
|
||||
" async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n",
|
||||
" agents.append(\n",
|
||||
" initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n",
|
||||
" initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
|
||||
" )\n",
|
||||
" tasks = [async_agent.arun(q, callbacks=callbacks) for async_agent, q in zip(agents, questions)]\n",
|
||||
" tasks = [async_agent.arun(q) for async_agent, q in zip(agents, questions)]\n",
|
||||
" await asyncio.gather(*tasks)\n",
|
||||
" await aiosession.close()\n",
|
||||
"\n",
|
||||
@@ -375,14 +376,14 @@
|
||||
"aiosession = ClientSession()\n",
|
||||
"tracer = LangChainTracer()\n",
|
||||
"tracer.load_default_session()\n",
|
||||
"callbacks = [StdOutCallbackHandler(), tracer]\n",
|
||||
"manager = CallbackManager([StdOutCallbackHandler(), tracer])\n",
|
||||
"\n",
|
||||
"# Pass the manager into the llm if you want llm calls traced.\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager)\n",
|
||||
"\n",
|
||||
"async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n",
|
||||
"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n",
|
||||
"await async_agent.arun(questions[0], callbacks=callbacks)\n",
|
||||
"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
|
||||
"await async_agent.arun(questions[0])\n",
|
||||
"await aiosession.close()"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -373,7 +373,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tools = get_tools(\"whats the weather?\")\n",
|
||||
"tool_names = [tool.name for tool in tools]\n",
|
||||
"agent = LLMSingleActionAgent(\n",
|
||||
" llm_chain=llm_chain, \n",
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -5,158 +5,57 @@
|
||||
"id": "8f210ec3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Shell Tool\n",
|
||||
"\n",
|
||||
"Giving agents access to the shell is powerful (though risky outside a sandboxed environment).\n",
|
||||
"\n",
|
||||
"The LLM can use it to execute any shell commands. A common use case for this is letting the LLM interact with your local file system."
|
||||
"# Bash\n",
|
||||
"It can often be useful to have an LLM generate bash commands, and then run them. A common use case for this is letting the LLM interact with your local file system. We provide an easy util to execute bash commands."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f7b3767b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.tools import ShellTool\n",
|
||||
"\n",
|
||||
"shell_tool = ShellTool()"
|
||||
"from langchain.utilities import BashProcess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "c92ac832-556b-4f66-baa4-b78f965dfba0",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello World!\n",
|
||||
"\n",
|
||||
"real\t0m0.000s\n",
|
||||
"user\t0m0.000s\n",
|
||||
"sys\t0m0.000s\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(shell_tool.run({\"commands\": [\"echo 'Hello World!'\", \"time\"]}))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2fa952fc",
|
||||
"id": "cf1c92f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"### Use with Agents\n",
|
||||
"\n",
|
||||
"As with all tools, these can be given to an agent to accomplish more complex tasks. Let's have the agent fetch some links from a web page."
|
||||
"bash = BashProcess()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "851fee9f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"id": "2fa952fc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mQuestion: What is the task?\n",
|
||||
"Thought: We need to download the langchain.com webpage and extract all the URLs from it. Then we need to sort the URLs and return them.\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"shell\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"commands\": [\n",
|
||||
" \"curl -s https://langchain.com | grep -o 'http[s]*://[^\\\" ]*' | sort\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\u001b[0m"
|
||||
"bash.ipynb\n",
|
||||
"google_search.ipynb\n",
|
||||
"python.ipynb\n",
|
||||
"requests.ipynb\n",
|
||||
"serpapi.ipynb\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mhttps://blog.langchain.dev/\n",
|
||||
"https://discord.gg/6adMQxSpJS\n",
|
||||
"https://docs.langchain.com/docs/\n",
|
||||
"https://github.com/hwchase17/chat-langchain\n",
|
||||
"https://github.com/hwchase17/langchain\n",
|
||||
"https://github.com/hwchase17/langchainjs\n",
|
||||
"https://github.com/sullivan-sean/chat-langchainjs\n",
|
||||
"https://js.langchain.com/docs/\n",
|
||||
"https://python.langchain.com/en/latest/\n",
|
||||
"https://twitter.com/langchainai\n",
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mThe URLs have been successfully extracted and sorted. We can return the list of URLs as the final answer.\n",
|
||||
"Final Answer: [\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'[\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.agents import initialize_agent\n",
|
||||
"from langchain.agents import AgentType\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"\n",
|
||||
"shell_tool.description = shell_tool.description + f\"args {shell_tool.args}\".replace(\"{\", \"{{\").replace(\"}\", \"}}\")\n",
|
||||
"self_ask_with_search = initialize_agent([shell_tool], llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)\n",
|
||||
"self_ask_with_search.run(\"Download the langchain.com webpage and grep for all urls. Return only a sorted list of them. Be sure to use double quotes.\")"
|
||||
"print(bash.run(\"ls\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8d0ea3ac-0890-4e39-9cec-74bd80b4b8b8",
|
||||
"id": "851fee9f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@@ -178,7 +77,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.tools import DuckDuckGoSearchRun"
|
||||
"from langchain.tools import DuckDuckGoSearchTool"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -37,7 +37,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search = DuckDuckGoSearchRun()"
|
||||
"search = DuckDuckGoSearchTool()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# File System Tools\n",
|
||||
"\n",
|
||||
"LangChain provides tools for interacting with a local file system out of the box. This notebook walks through some of them.\n",
|
||||
"\n",
|
||||
"Note: these tools are not recommended for use outside a sandboxed environment! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First, we'll import the tools."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.tools.file_management import (\n",
|
||||
" ReadFileTool,\n",
|
||||
" CopyFileTool,\n",
|
||||
" DeleteFileTool,\n",
|
||||
" MoveFileTool,\n",
|
||||
" WriteFileTool,\n",
|
||||
" ListDirectoryTool,\n",
|
||||
")\n",
|
||||
"from langchain.agents.agent_toolkits import FileManagementToolkit\n",
|
||||
"from tempfile import TemporaryDirectory\n",
|
||||
"\n",
|
||||
"# We'll make a temporary directory to avoid clutter\n",
|
||||
"working_directory = TemporaryDirectory()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## The FileManagementToolkit\n",
|
||||
"\n",
|
||||
"If you want to provide all the file tooling to your agent, it's easy to do so with the toolkit. We'll pass the temporary directory in as a root directory as a workspace for the LLM.\n",
|
||||
"\n",
|
||||
"It's recommended to always pass in a root directory, since without one, it's easy for the LLM to pollute the working directory, and without one, there isn't any validation against\n",
|
||||
"straightforward prompt injection."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[CopyFileTool(name='copy_file', description='Create a copy of a file in a specified location', args_schema=<class 'langchain.tools.file_management.copy.FileCopyInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" DeleteFileTool(name='file_delete', description='Delete a file', args_schema=<class 'langchain.tools.file_management.delete.FileDeleteInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" FileSearchTool(name='file_search', description='Recursively search for files in a subdirectory that match the regex pattern', args_schema=<class 'langchain.tools.file_management.file_search.FileSearchInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" MoveFileTool(name='move_file', description='Move or rename a file from one location to another', args_schema=<class 'langchain.tools.file_management.move.FileMoveInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" ReadFileTool(name='read_file', description='Read file from disk', args_schema=<class 'langchain.tools.file_management.read.ReadFileInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" WriteFileTool(name='write_file', description='Write file to disk', args_schema=<class 'langchain.tools.file_management.write.WriteFileInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" ListDirectoryTool(name='list_directory', description='List files and directories in a specified folder', args_schema=<class 'langchain.tools.file_management.list_dir.DirectoryListingInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"toolkit = FileManagementToolkit(root_dir=str(working_directory.name)) # If you don't provide a root_dir, operations will default to the current working directory\n",
|
||||
"toolkit.get_tools()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Selecting File System Tools\n",
|
||||
"\n",
|
||||
"If you only want to select certain tools, you can pass them in as arguments when initializing the toolkit, or you can individually initialize the desired tools."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[ReadFileTool(name='read_file', description='Read file from disk', args_schema=<class 'langchain.tools.file_management.read.ReadFileInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" WriteFileTool(name='write_file', description='Write file to disk', args_schema=<class 'langchain.tools.file_management.write.WriteFileInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug'),\n",
|
||||
" ListDirectoryTool(name='list_directory', description='List files and directories in a specified folder', args_schema=<class 'langchain.tools.file_management.list_dir.DirectoryListingInput'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1156f4350>, root_dir='/var/folders/gf/6rnp_mbx5914kx7qmmh7xzmw0000gn/T/tmpxb8c3aug')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tools = FileManagementToolkit(root_dir=str(working_directory.name), selected_tools=[\"read_file\", \"write_file\", \"list_directory\"]).get_tools()\n",
|
||||
"tools"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'File written successfully to example.txt.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"read_tool, write_tool, list_tool = tools\n",
|
||||
"write_tool.run({\"file_path\": \"example.txt\", \"text\": \"Hello World!\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'example.txt'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# List files in the working directory\n",
|
||||
"list_tool.run({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,8 +24,8 @@
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"echo \"Hello World\"\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['echo \"Hello World\"']\u001b[0m\n",
|
||||
"```\u001b[0m['```bash', 'echo \"Hello World\"', '```']\n",
|
||||
"\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
@@ -65,7 +65,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -93,7 +93,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -107,8 +107,8 @@
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"printf \"Hello World\\n\"\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['printf \"Hello World\\\\n\"']\u001b[0m\n",
|
||||
"```\u001b[0m['```bash', 'printf \"Hello World\\\\n\"', '```']\n",
|
||||
"\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
@@ -120,7 +120,7 @@
|
||||
"'Hello World\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -132,114 +132,6 @@
|
||||
"\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Persistent Terminal\n",
|
||||
"\n",
|
||||
"By default, the chain will run in a separate subprocess each time it is called. This behavior can be changed by instantiating with a persistent bash process."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"ls\n",
|
||||
"cd ..\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mapi.ipynb\t\t\tllm_summarization_checker.ipynb\n",
|
||||
"constitutional_chain.ipynb\tmoderation.ipynb\n",
|
||||
"llm_bash.ipynb\t\t\topenai_openapi.yaml\n",
|
||||
"llm_checker.ipynb\t\topenapi.ipynb\n",
|
||||
"llm_math.ipynb\t\t\tpal.ipynb\n",
|
||||
"llm_requests.ipynb\t\tsqlite.ipynb\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.utilities.bash import BashProcess\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"persistent_process = BashProcess(persistent=True)\n",
|
||||
"bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n",
|
||||
"\n",
|
||||
"text = \"List the current directory then move up a level.\"\n",
|
||||
"\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"ls\n",
|
||||
"cd ..\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mexamples\t\tgetting_started.ipynb\tindex_examples\n",
|
||||
"generic\t\t\thow_to_guides.rst\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Run the same command again and see that the state is maintained between calls\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -258,7 +150,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "593f7553-7038-498e-96d4-8255e5ce34f0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Creating a custom Chain\n",
|
||||
"\n",
|
||||
"To implement your own custom chain you can subclass `BaseChain` and implement the following methods:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
|
||||
"metadata": {
|
||||
"tags": [],
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"from typing import Any, Dict, List, Optional\n",
|
||||
"\n",
|
||||
"from pydantic import Extra\n",
|
||||
"\n",
|
||||
"from langchain.base_language import BaseLanguageModel\n",
|
||||
"from langchain.callbacks.manager import (\n",
|
||||
" AsyncCallbackManagerForChainRun,\n",
|
||||
" CallbackManagerForChainRun,\n",
|
||||
")\n",
|
||||
"from langchain.chains.base import Chain\n",
|
||||
"from langchain.prompts.base import BasePromptTemplate\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MyCustomChain(Chain):\n",
|
||||
" \"\"\"\n",
|
||||
" An example of a custom chain.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" prompt: BasePromptTemplate\n",
|
||||
" \"\"\"Prompt object to use.\"\"\"\n",
|
||||
" llm: BaseLanguageModel\n",
|
||||
" output_key: str = \"text\" #: :meta private:\n",
|
||||
"\n",
|
||||
" class Config:\n",
|
||||
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
|
||||
"\n",
|
||||
" extra = Extra.forbid\n",
|
||||
" arbitrary_types_allowed = True\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def input_keys(self) -> List[str]:\n",
|
||||
" \"\"\"Will be whatever keys the prompt expects.\n",
|
||||
"\n",
|
||||
" :meta private:\n",
|
||||
" \"\"\"\n",
|
||||
" return self.prompt.input_variables\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def output_keys(self) -> List[str]:\n",
|
||||
" \"\"\"Will always return text key.\n",
|
||||
"\n",
|
||||
" :meta private:\n",
|
||||
" \"\"\"\n",
|
||||
" return [self.output_key]\n",
|
||||
"\n",
|
||||
" def _call(\n",
|
||||
" self,\n",
|
||||
" inputs: Dict[str, Any],\n",
|
||||
" run_manager: Optional[CallbackManagerForChainRun] = None,\n",
|
||||
" ) -> Dict[str, str]:\n",
|
||||
" # Your custom chain logic goes here\n",
|
||||
" # This is just an example that mimics LLMChain\n",
|
||||
" prompt_value = self.prompt.format_prompt(**inputs)\n",
|
||||
" \n",
|
||||
" # Whenever you call a language model, or another chain, you should pass\n",
|
||||
" # a callback manager to it. This allows the inner run to be tracked by\n",
|
||||
" # any callbacks that are registered on the outer run.\n",
|
||||
" # You can always obtain a callback manager for this by calling\n",
|
||||
" # `run_manager.get_child()` as shown below.\n",
|
||||
" response = self.llm.generate_prompt(\n",
|
||||
" [prompt_value],\n",
|
||||
" callbacks=run_manager.get_child() if run_manager else None\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # If you want to log something about this run, you can do so by calling\n",
|
||||
" # methods on the `run_manager`, as shown below. This will trigger any\n",
|
||||
" # callbacks that are registered for that event.\n",
|
||||
" if run_manager:\n",
|
||||
" run_manager.on_text(\"Log something about this run\")\n",
|
||||
" \n",
|
||||
" return {self.output_key: response.generations[0][0].text}\n",
|
||||
"\n",
|
||||
" async def _acall(\n",
|
||||
" self,\n",
|
||||
" inputs: Dict[str, Any],\n",
|
||||
" run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
|
||||
" ) -> Dict[str, str]:\n",
|
||||
" # Your custom chain logic goes here\n",
|
||||
" # This is just an example that mimics LLMChain\n",
|
||||
" prompt_value = self.prompt.format_prompt(**inputs)\n",
|
||||
" \n",
|
||||
" # Whenever you call a language model, or another chain, you should pass\n",
|
||||
" # a callback manager to it. This allows the inner run to be tracked by\n",
|
||||
" # any callbacks that are registered on the outer run.\n",
|
||||
" # You can always obtain a callback manager for this by calling\n",
|
||||
" # `run_manager.get_child()` as shown below.\n",
|
||||
" response = await self.llm.agenerate_prompt(\n",
|
||||
" [prompt_value],\n",
|
||||
" callbacks=run_manager.get_child() if run_manager else None\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # If you want to log something about this run, you can do so by calling\n",
|
||||
" # methods on the `run_manager`, as shown below. This will trigger any\n",
|
||||
" # callbacks that are registered for that event.\n",
|
||||
" if run_manager:\n",
|
||||
" await run_manager.on_text(\"Log something about this run\")\n",
|
||||
" \n",
|
||||
" return {self.output_key: response.generations[0][0].text}\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def _chain_type(self) -> str:\n",
|
||||
" return \"my_custom_chain\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "18361f89",
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
|
||||
"Log something about this run\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
|
||||
"from langchain.chat_models.openai import ChatOpenAI\n",
|
||||
"from langchain.prompts.prompt import PromptTemplate\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"chain = MyCustomChain(\n",
|
||||
" prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n",
|
||||
" llm=ChatOpenAI()\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -589,6 +589,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains.llm import LLMChain\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
|
||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||
@@ -596,7 +597,7 @@
|
||||
"# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n",
|
||||
"# and a separate, non-streaming llm for question generation\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"streaming_llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
|
||||
"streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"\n",
|
||||
"question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
|
||||
"doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n",
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bda1f3f5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Arxiv\n",
|
||||
"\n",
|
||||
"[arXiv](https://arxiv.org/) is an open-access archive for 2 million scholarly articles in the fields of physics, mathematics, computer science, quantitative biology, quantitative finance, statistics, electrical engineering and systems science, and economics.\n",
|
||||
"\n",
|
||||
"This notebook shows how to load scientific articles from `Arxiv.org` into a document format that we can use downstream."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1b7a1eef-7bf7-4e7d-8bfc-c4e27c9488cb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2abd5578-aa3d-46b9-99af-8b262f0b3df8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First, you need to install `arxiv` python package."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b674aaea-ed3a-4541-8414-260a8f67f623",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install arxiv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "094b5f13-7e54-4354-9d83-26d6926ecaa0",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"Second, you need to install `PyMuPDF` python package which transform PDF files from the `arxiv.org` site into the text fromat."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7cd91121-2e96-43ba-af50-319853695f86",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install pymupdf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "95f05e1c-195e-4e2b-ae8e-8d6637f15be6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e29b954c-1407-4797-ae21-6ba8937156be",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`ArxivLoader` has these arguments:\n",
|
||||
"- `query`: free text which used to find documents in the Arxiv\n",
|
||||
"- optional `load_max_docs`: default=100. Use it to limit number of downloaded documents. It takes time to download all 100 documents, so use a small number for experiments.\n",
|
||||
"- optional `load_all_available_meta`: default=False. By defaul only the most important fields downloaded: `Published` (date when document was published/last updated), `Title`, `Authors`, `Summary`. If True, other fields also downloaded."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9bfd5e46",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders.base import Document\n",
|
||||
"from langchain.document_loaders import ArxivLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "700e4ef2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = ArxivLoader(query=\"1605.08386\", load_max_docs=2).load()\n",
|
||||
"len(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8977bac0-0042-4f23-9754-247dbd32439b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'Published': '2016-05-26',\n",
|
||||
" 'Title': 'Heat-bath random walks with Markov bases',\n",
|
||||
" 'Authors': 'Caprice Stanley, Tobias Windisch',\n",
|
||||
" 'Summary': 'Graphs on lattice points are studied whose edges come from a finite set of\\nallowed moves of arbitrary length. We show that the diameter of these graphs on\\nfibers of a fixed integer matrix can be bounded from above by a constant. We\\nthen study the mixing behaviour of heat-bath random walks on these graphs. We\\nalso state explicit conditions on the set of moves so that the heat-bath random\\nwalk, a generalization of the Glauber dynamics, is an expander in fixed\\ndimension.'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"doc[0].metadata # meta-information of the Document"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "46969806-45a9-4c4d-a61b-cfb9658fc9de",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'arXiv:1605.08386v1 [math.CO] 26 May 2016\\nHEAT-BATH RANDOM WALKS WITH MARKOV BASES\\nCAPRICE STANLEY AND TOBIAS WINDISCH\\nAbstract. Graphs on lattice points are studied whose edges come from a finite set of\\nallowed moves of arbitrary length. We show that the diameter of these graphs on fibers of a\\nfixed integer matrix can be bounded from above by a constant. We then study the mixing\\nbehaviour of heat-b'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"doc[0].page_content[:400] # all pages of the Document content\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,330 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "13afcae7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Self-querying retriever\n",
|
||||
"In the notebook we'll demo the `SelfQueryRetriever`, which, as the name suggests, has the ability to query itself. Specifically, given any natural language query, the retriever uses a query-constructing LLM chain to write a structured query and then applies that structured query to it's underlying VectorStore. This allows the retriever to not only use the user-input query for semantic similarity comparison with the contents of stored documented, but to also extract filters from the user query on the metadata of stored documents and to execute those filter."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "68e75fb9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Creating a Pinecone index\n",
|
||||
"First we'll want to create a Pinecone VectorStore and seed it with some data. We've created a small demo set of documents that contain summaries of movies.\n",
|
||||
"\n",
|
||||
"NOTE: The self-query retriever currently only has built-in support for Pinecone VectorStore.\n",
|
||||
"\n",
|
||||
"NOTE: The self-query retriever requires you to have `lark` installed (`pip install lark`)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "63a8af5b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip install lark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "3eb9c9a4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pinecone/index.py:4: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
|
||||
" from tqdm.autonotebook import tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pinecone\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"pinecone.init(api_key=os.environ[\"PINECONE_API_KEY\"], environment=os.environ[\"PINECONE_ENV\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "cb4a5787",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema import Document\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import Pinecone\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"# create new index\n",
|
||||
"pinecone.create_index(\"langchain-self-retriever-demo\", dimension=1536)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "bcbe04d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = [\n",
|
||||
" Document(page_content=\"A bunch of scientists bring back dinosaurs and mayhem breaks loose\", metadata={\"year\": 1993, \"rating\": 7.7, \"genre\": [\"action\", \"science fiction\"]}),\n",
|
||||
" Document(page_content=\"Leo DiCaprio gets lost in a dream within a dream within a dream within a ...\", metadata={\"year\": 2010, \"director\": \"Christopher Nolan\", \"rating\": 8.2}),\n",
|
||||
" Document(page_content=\"A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea\", metadata={\"year\": 2006, \"director\": \"Satoshi Kon\", \"rating\": 8.6}),\n",
|
||||
" Document(page_content=\"A bunch of normal-sized women are supremely wholesome and some men pine after them\", metadata={\"year\": 2019, \"director\": \"Greta Gerwig\", \"rating\": 8.3}),\n",
|
||||
" Document(page_content=\"Toys come alive and have a blast doing so\", metadata={\"year\": 1995, \"genre\": \"animated\"}),\n",
|
||||
" Document(page_content=\"Three men walk into the Zone, three men walk out of the Zone\", metadata={\"year\": 1979, \"rating\": 9.9, \"director\": \"Andrei Tarkovsky\", \"genre\": [\"science fiction\", \"thriller\"], \"rating\": 9.9})\n",
|
||||
"]\n",
|
||||
"vectorstore = Pinecone.from_documents(\n",
|
||||
" docs, embeddings, index_name=\"langchain-self-retriever-demo\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5ecaab6d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Creating our self-querying retriever\n",
|
||||
"Now we can instantiate our retriever. To do this we'll need to provide some information upfront about the metadata fields that our documents support and a short description of the document contents."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "86e34dbf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n",
|
||||
"from langchain.chains.query_constructor.base import AttributeInfo\n",
|
||||
"\n",
|
||||
"metadata_field_info=[\n",
|
||||
" AttributeInfo(\n",
|
||||
" name=\"genre\",\n",
|
||||
" description=\"The genre of the movie\", \n",
|
||||
" type=\"string or list[string]\", \n",
|
||||
" ),\n",
|
||||
" AttributeInfo(\n",
|
||||
" name=\"year\",\n",
|
||||
" description=\"The year the movie was released\", \n",
|
||||
" type=\"integer\", \n",
|
||||
" ),\n",
|
||||
" AttributeInfo(\n",
|
||||
" name=\"director\",\n",
|
||||
" description=\"The name of the movie director\", \n",
|
||||
" type=\"string\", \n",
|
||||
" ),\n",
|
||||
" AttributeInfo(\n",
|
||||
" name=\"rating\",\n",
|
||||
" description=\"A 1-10 rating for the movie\",\n",
|
||||
" type=\"float\"\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"document_content_description = \"Brief summary of a movie\"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"retriever = SelfQueryRetriever.from_llm(llm, vectorstore, document_content_description, metadata_field_info, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ea9df8d4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Testing it out\n",
|
||||
"And now we can try actually using our retriever!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "38a126e9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='dinosaur' filter=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'genre': ['action', 'science fiction'], 'rating': 7.7, 'year': 1993.0}),\n",
|
||||
" Document(page_content='Toys come alive and have a blast doing so', metadata={'genre': 'animated', 'year': 1995.0}),\n",
|
||||
" Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'director': 'Satoshi Kon', 'rating': 8.6, 'year': 2006.0}),\n",
|
||||
" Document(page_content='Leo DiCaprio gets lost in a dream within a dream within a dream within a ...', metadata={'director': 'Christopher Nolan', 'rating': 8.2, 'year': 2010.0})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# This example only specifies a relevant query\n",
|
||||
"retriever.get_relevant_documents(\"What are some movies about dinosaurs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "fc3f1e6e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'director': 'Satoshi Kon', 'rating': 8.6, 'year': 2006.0}),\n",
|
||||
" Document(page_content='Three men walk into the Zone, three men walk out of the Zone', metadata={'director': 'Andrei Tarkovsky', 'genre': ['science fiction', 'thriller'], 'rating': 9.9, 'year': 1979.0})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# This example only specifies a filter\n",
|
||||
"retriever.get_relevant_documents(\"I want to watch a movie rated higher than 8.5\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "b19d4da0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='A bunch of normal-sized women are supremely wholesome and some men pine after them', metadata={'director': 'Greta Gerwig', 'rating': 8.3, 'year': 2019.0})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# This example specifies a query and a filter\n",
|
||||
"retriever.get_relevant_documents(\"Has Greta Gerwig directed any movies about women\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "f900e40e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction'), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Three men walk into the Zone, three men walk out of the Zone', metadata={'director': 'Andrei Tarkovsky', 'genre': ['science fiction', 'thriller'], 'rating': 9.9, 'year': 1979.0})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# This example specifies a composite filter\n",
|
||||
"retriever.get_relevant_documents(\"What's a highly rated (above 8.5) science fiction film?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "12a51522",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990.0), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005.0), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Toys come alive and have a blast doing so', metadata={'genre': 'animated', 'year': 1995.0})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# This example specifies a query and composite filter\n",
|
||||
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "69bbd809",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,179 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "683953b3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LanceDB\n",
|
||||
"\n",
|
||||
"This notebook shows how to use functionality related to the LanceDB vector database based on the Lance data format."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "bfcf346a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install lancedb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "aac9563e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import LanceDB"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "a3c3999a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"loader = TextLoader('../../../state_of_the_union.txt')\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"documents = CharacterTextSplitter().split_documents(documents)\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "6e104aee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import lancedb\n",
|
||||
"\n",
|
||||
"db = lancedb.connect('/tmp/lancedb')\n",
|
||||
"table = db.create_table(\"my_table\", data=[\n",
|
||||
" {\"vector\": embeddings.embed_query(\"Hello World\"), \"text\": \"Hello World\", \"id\": \"1\"}\n",
|
||||
"], mode=\"overwrite\")\n",
|
||||
"\n",
|
||||
"docsearch = LanceDB.from_documents(documents, embeddings, connection=table)\n",
|
||||
"\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = docsearch.similarity_search(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "9c608226",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n",
|
||||
"\n",
|
||||
"Officer Mora was 27 years old. \n",
|
||||
"\n",
|
||||
"Officer Rivera was 22. \n",
|
||||
"\n",
|
||||
"Both Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n",
|
||||
"\n",
|
||||
"I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. \n",
|
||||
"\n",
|
||||
"I’ve worked on these issues a long time. \n",
|
||||
"\n",
|
||||
"I know what works: Investing in crime preventionand community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety. \n",
|
||||
"\n",
|
||||
"So let’s not abandon our streets. Or choose between safety and equal justice. \n",
|
||||
"\n",
|
||||
"Let’s come together to protect our communities, restore trust, and hold law enforcement accountable. \n",
|
||||
"\n",
|
||||
"That’s why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers. \n",
|
||||
"\n",
|
||||
"That’s why the American Rescue Plan provided $350 Billion that cities, states, and counties can use to hire more police and invest in proven strategies like community violence interruption—trusted messengers breaking the cycle of violence and trauma and giving young people hope. \n",
|
||||
"\n",
|
||||
"We should all agree: The answer is not to Defund the police. The answer is to FUND the police with the resources and training they need to protect our communities. \n",
|
||||
"\n",
|
||||
"I ask Democrats and Republicans alike: Pass my budget and keep our neighborhoods safe. \n",
|
||||
"\n",
|
||||
"And I will keep doing everything in my power to crack down on gun trafficking and ghost guns you can buy online and make at home—they have no serial numbers and can’t be traced. \n",
|
||||
"\n",
|
||||
"And I ask Congress to pass proven measures to reduce gun violence. Pass universal background checks. Why should anyone on a terrorist list be able to purchase a weapon? \n",
|
||||
"\n",
|
||||
"Ban assault weapons and high-capacity magazines. \n",
|
||||
"\n",
|
||||
"Repeal the liability shield that makes gun manufacturers the only industry in America that can’t be sued. \n",
|
||||
"\n",
|
||||
"These laws don’t infringe on the Second Amendment. They save lives. \n",
|
||||
"\n",
|
||||
"The most fundamental right in America is the right to vote – and to have it counted. And it’s under assault. \n",
|
||||
"\n",
|
||||
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
|
||||
"\n",
|
||||
"We cannot let this happen. \n",
|
||||
"\n",
|
||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||
"\n",
|
||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||
"\n",
|
||||
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
|
||||
"\n",
|
||||
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. \n",
|
||||
"\n",
|
||||
"A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n",
|
||||
"\n",
|
||||
"And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n",
|
||||
"\n",
|
||||
"We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling. \n",
|
||||
"\n",
|
||||
"We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers. \n",
|
||||
"\n",
|
||||
"We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a359ed74",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -75,8 +75,7 @@
|
||||
" \"vectorizer\": \"text2vec-openai\",\n",
|
||||
" \"moduleConfig\": {\n",
|
||||
" \"text2vec-openai\": {\n",
|
||||
" \"model\": \"ada\",\n",
|
||||
" \"modelVersion\": \"002\",\n",
|
||||
" \"model\": \"babbage\",\n",
|
||||
" \"type\": \"text\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
|
||||
@@ -80,8 +80,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
|
||||
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -373,8 +373,9 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
|
||||
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])\n"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -785,9 +785,7 @@
|
||||
"id": "9df0dab8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!rm .langchain.db sqlite.db"
|
||||
]
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -21,13 +21,14 @@
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI, Anthropic\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"from langchain.schema import HumanMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"id": "77f60a4b-f786-41f2-972e-e5bb8a48dcd5",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -78,7 +79,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
|
||||
"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"resp = llm(\"Write me a song about sparkling water.\")"
|
||||
]
|
||||
},
|
||||
@@ -94,7 +95,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "a35373f1-9ee6-4753-a343-5aee749b8527",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -135,7 +136,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "22665f16-e05b-473c-a4bd-ad75744ea024",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -190,7 +191,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
|
||||
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])"
|
||||
]
|
||||
},
|
||||
@@ -204,7 +205,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -244,7 +245,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = Anthropic(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
|
||||
"llm = Anthropic(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"llm(\"Write me a song about sparkling water.\")"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
"source": [
|
||||
"from langchain import PromptTemplate, LLMChain\n",
|
||||
"from langchain.llms import GPT4All\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
|
||||
]
|
||||
},
|
||||
@@ -123,9 +124,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Callbacks support token-wise streaming\n",
|
||||
"callbacks = [StreamingStdOutCallbackHandler()]\n",
|
||||
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
|
||||
"# Verbose is required to pass to the callback manager\n",
|
||||
"llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)"
|
||||
"llm = GPT4All(model=local_path, callback_manager=callback_manager, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PipelineAI\n",
|
||||
"\n",
|
||||
"PipelineAI allows you to run your ML models at scale in the cloud. It also provides API access to [several LLM models](https://pipeline.ai).\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use Langchain with [PipelineAI](https://docs.pipeline.ai/docs)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install pipeline-ai\n",
|
||||
"The `pipeline-ai` library is required to use the `PipelineAI` API, AKA `Pipeline Cloud`. Install `pipeline-ai` using `pip install pipeline-ai`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install the package\n",
|
||||
"!pip install pipeline-ai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from langchain.llms import PipelineAI\n",
|
||||
"from langchain import PromptTemplate, LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set the Environment API Key\n",
|
||||
"Make sure to get your API key from PipelineAI. Check out the [cloud quickstart guide](https://docs.pipeline.ai/docs/cloud-quickstart). You'll be given a 30 day free trial with 10 hours of serverless GPU compute to test different models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ[\"PIPELINE_API_KEY\"] = \"YOUR_API_KEY_HERE\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create the PipelineAI instance\n",
|
||||
"When instantiating PipelineAI, you need to specify the id or tag of the pipeline you want to use, e.g. `pipeline_key = \"public/gpt-j:base\"`. You then have the option of passing additional pipeline-specific keyword arguments:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = PipelineAI(pipeline_key=\"YOUR_PIPELINE_KEY\", pipeline_kwargs={...})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create a Prompt Template\n",
|
||||
"We will create a prompt template for Question and Answer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initiate the LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run the LLMChain\n",
|
||||
"Provide a question and run the LLMChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
|
||||
"\n",
|
||||
"llm_chain.run(question)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -17,9 +17,7 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ad0b5edf",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Feast\n",
|
||||
"\n",
|
||||
@@ -213,241 +211,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c4049990-651d-44d3-82b1-0cd122da55c1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tecton\n",
|
||||
"\n",
|
||||
"Above, we showed how you could use Feast, a popular open source and self-managed feature store, with LangChain. Our examples below will show a similar integration using Tecton. Tecton is a fully managed feature platform built to orchestrate the complete ML feature lifecycle, from transformation to online serving, with enterprise-grade SLAs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7bb4dba1-0678-4ea4-be0a-d353c0b13fc2",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"### Prerequisites\n",
|
||||
"\n",
|
||||
"* Tecton Deployment (sign up at [https://tecton.ai](https://tecton.ai))\n",
|
||||
"* `TECTON_API_KEY` environment variable set to a valid Service Account key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ac9eb618-8c52-4cd6-bb8e-9c99a150dfa6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"### Define and Load Features\n",
|
||||
"\n",
|
||||
"We will use the user_transaction_counts Feature View from the [Tecton tutorial](https://docs.tecton.ai/docs/tutorials/tecton-fundamentals) as part of a Feature Service. For simplicity, we are only using a single Feature View; however, more sophisticated applications may require more feature views to retrieve the features needed for its prompt.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"user_transaction_metrics = FeatureService(\n",
|
||||
" name = \"user_transaction_metrics\",\n",
|
||||
" features = [user_transaction_counts]\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The above Feature Service is expected to be [applied to a live workspace](https://docs.tecton.ai/docs/applying-feature-repository-changes-to-a-workspace). For this example, we will be using the \"prod\" workspace."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"id": "32e9675d-a7e5-429f-906f-2260294d3e46",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tecton\n",
|
||||
"\n",
|
||||
"workspace = tecton.get_workspace(\"prod\")\n",
|
||||
"feature_service = workspace.get_feature_service(\"user_transaction_metrics\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29b7550c-0eb4-4bd1-a501-1c63fb77aa56",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Prompts\n",
|
||||
"\n",
|
||||
"Here we will set up a custom TectonPromptTemplate. This prompt template will take in a user_id , look up their stats, and format those stats into a prompt.\n",
|
||||
"\n",
|
||||
"Note that the input to this prompt template is just `user_id`, since that is the only user defined piece (all other variables are looked up inside the prompt template)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"id": "6fb77ea4-64c6-4e48-a783-bd1ece021b82",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate, StringPromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"id": "02a98fbc-8135-4b11-bf60-85d28e426667",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Given the vendor's up to date transaction stats, write them a note based on the following rules:\n",
|
||||
"\n",
|
||||
"1. If they had a transaction in the last day, write a short congratulations message on their recent sales\n",
|
||||
"2. If no transaction in the last day, but they had a transaction in the last 30 days, playfully encourage them to sell more.\n",
|
||||
"3. Always add a silly joke about chickens at the end\n",
|
||||
"\n",
|
||||
"Here are the vendor's stats:\n",
|
||||
"Number of Transactions Last Day: {transaction_count_1d}\n",
|
||||
"Number of Transactions Last 30 Days: {transaction_count_30d}\n",
|
||||
"\n",
|
||||
"Your response:\"\"\"\n",
|
||||
"prompt = PromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 78,
|
||||
"id": "a35cdfd5-6ccc-4394-acfe-60d53804be51",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TectonPromptTemplate(StringPromptTemplate):\n",
|
||||
" \n",
|
||||
" def format(self, **kwargs) -> str:\n",
|
||||
" user_id = kwargs.pop(\"user_id\")\n",
|
||||
" feature_vector = feature_service.get_online_features(join_keys={\"user_id\": user_id}).to_dict()\n",
|
||||
" kwargs[\"transaction_count_1d\"] = feature_vector[\"user_transaction_counts.transaction_count_1d_1d\"]\n",
|
||||
" kwargs[\"transaction_count_30d\"] = feature_vector[\"user_transaction_counts.transaction_count_30d_1d\"]\n",
|
||||
" return prompt.format(**kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 79,
|
||||
"id": "d5915df0-fb16-4770-8a82-22f885b74d1a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt_template = TectonPromptTemplate(input_variables=[\"user_id\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 80,
|
||||
"id": "a36abfc8-ea60-4ae0-a36d-d7b639c7307c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Given the vendor's up to date transaction stats, write them a note based on the following rules:\n",
|
||||
"\n",
|
||||
"1. If they had a transaction in the last day, write a short congratulations message on their recent sales\n",
|
||||
"2. If no transaction in the last day, but they had a transaction in the last 30 days, playfully encourage them to sell more.\n",
|
||||
"3. Always add a silly joke about chickens at the end\n",
|
||||
"\n",
|
||||
"Here are the vendor's stats:\n",
|
||||
"Number of Transactions Last Day: 657\n",
|
||||
"Number of Transactions Last 30 Days: 20326\n",
|
||||
"\n",
|
||||
"Your response:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(prompt_template.format(user_id=\"user_469998441571\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f8d4b905-1051-4303-9c33-8eddb65c1274",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"### Use in a chain\n",
|
||||
"\n",
|
||||
"We can now use this in a chain, successfully creating a chain that achieves personalization backed by the Tecton Feature Platform"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 81,
|
||||
"id": "ffb60cd0-8e3c-4c9d-b639-43d766e12c4c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 82,
|
||||
"id": "3918abc7-00b5-466f-bdfc-ab046cd282da",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = LLMChain(llm=ChatOpenAI(), prompt=prompt_template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 83,
|
||||
"id": "e7d91c4b-3e99-40cc-b3e9-a004c8c9193e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Wow, congratulations on your recent sales! Your business is really soaring like a chicken on a hot air balloon! Keep up the great work!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 83,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.run(\"user_469998441571\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f752b924-caf9-4f7a-b78b-cb8c8ada8c2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -466,7 +229,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,15 +6,16 @@ First, you should install tracing and set up your environment properly.
|
||||
You can use either a locally hosted version of this (uses Docker) or a cloud hosted version (in closed alpha).
|
||||
If you're interested in using the hosted platform, please fill out the form [here](https://forms.gle/tRCEMSeopZf6TE3b6).
|
||||
|
||||
|
||||
- [Locally Hosted Setup](./tracing/local_installation.md)
|
||||
- [Cloud Hosted Setup](./tracing/hosted_installation.md)
|
||||
|
||||
## Tracing Walkthrough
|
||||
|
||||
When you first access the UI, you should see a page with your tracing sessions.
|
||||
An initial one "default" should already be created for you.
|
||||
A session is just a way to group traces together.
|
||||
If you click on a session, it will take you to a page with no recorded traces that says "No Runs."
|
||||
When you first access the UI, you should see a page with your tracing sessions.
|
||||
An initial one "default" should already be created for you.
|
||||
A session is just a way to group traces together.
|
||||
If you click on a session, it will take you to a page with no recorded traces that says "No Runs."
|
||||
You can create a new session with the new session form.
|
||||
|
||||

|
||||
@@ -34,7 +35,7 @@ We can keep on clicking further and further down to explore deeper and deeper.
|
||||
|
||||

|
||||
|
||||
We can also click on the "Explore" button of the top level run to dive even deeper.
|
||||
We can also click on the "Explore" button of the top level run to dive even deeper.
|
||||
Here, we can see the inputs and outputs in full, as well as all the nested traces.
|
||||
|
||||

|
||||
@@ -45,12 +46,11 @@ For example, here is the lowest level trace with the exact inputs/outputs to the
|
||||

|
||||
|
||||
## Changing Sessions
|
||||
|
||||
1. To initially record traces to a session other than `"default"`, you can set the `LANGCHAIN_SESSION` environment variable to the name of the session you want to record to:
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["LANGCHAIN_TRACING"] = "true"
|
||||
os.environ["LANGCHAIN_HANDLER"] = "langchain"
|
||||
os.environ["LANGCHAIN_SESSION"] = "my_session" # Make sure this session actually exists. You can create a new session in the UI.
|
||||
```
|
||||
|
||||
|
||||
@@ -14,6 +14,4 @@ Specific implementations of agent simulations (or parts of agent simulations) in
|
||||
|
||||
## Simulations with Multiple Agents
|
||||
- [Multi-Player D&D](agent_simulations/multi_player_dnd.ipynb): an example of how to use a generic dialogue simulator for multiple dialogue agents with a custom speaker-ordering, illustrated with a variant of the popular Dungeons & Dragons role playing game.
|
||||
- [Decentralized Speaker Selection](agent_simulations/multiagent_bidding.ipynb): an example of how to implement a multi-agent dialogue without a fixed schedule for who speaks when. Instead the agents decide for themselves who speaks by outputting bids to speak. This example shows how to do this in the context of a fictitious presidential debate.
|
||||
- [Authoritarian Speaker Selection](agent_simulations/multiagent_authoritarian.ipynb): an example of how to implement a multi-agent dialogue, where a privileged agent directs who speaks what. This example also showcases how to enable the privileged agent to determine when the conversation terminates. This example shows how to do this in the context of a fictitious news show.
|
||||
- [Generative Agents](agent_simulations/characters.ipynb): This notebook implements a generative agent based on the paper [Generative Agents: Interactive Simulacra of Human Behavior](https://arxiv.org/abs/2304.03442) by Park, et. al.
|
||||
|
||||
@@ -1,849 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Multi-agent authoritarian speaker selection\n",
|
||||
"\n",
|
||||
"This notebook showcases how to implement a multi-agent simulation where a privileged agent decides who to speak.\n",
|
||||
"This follows the polar opposite selection scheme as [multi-agent decentralized speaker selection](https://python.langchain.com/en/latest/use_cases/agent_simulations/multiagent_bidding.html).\n",
|
||||
"\n",
|
||||
"We show an example of this approach in the context of a fictitious simulation of a news network. This example will showcase how we can implement agents that\n",
|
||||
"- think before speaking\n",
|
||||
"- terminate the conversation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Import LangChain related modules "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import OrderedDict\n",
|
||||
"import functools\n",
|
||||
"import random\n",
|
||||
"import re\n",
|
||||
"import tenacity\n",
|
||||
"from typing import List, Dict, Callable\n",
|
||||
"\n",
|
||||
"from langchain.prompts import (\n",
|
||||
" ChatPromptTemplate, \n",
|
||||
" HumanMessagePromptTemplate,\n",
|
||||
" PromptTemplate\n",
|
||||
")\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.output_parsers import RegexParser\n",
|
||||
"from langchain.schema import (\n",
|
||||
" AIMessage,\n",
|
||||
" HumanMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" BaseMessage,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `DialogueAgent` and `DialogueSimulator` classes\n",
|
||||
"We will use the same `DialogueAgent` and `DialogueSimulator` classes defined in our other examples [Multi-Player Dungeons & Dragons](https://python.langchain.com/en/latest/use_cases/agent_simulations/multi_player_dnd.html) and [Decentralized Speaker Selection](https://python.langchain.com/en/latest/use_cases/agent_simulations/multiagent_bidding.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DialogueAgent:\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" name: str,\n",
|
||||
" system_message: SystemMessage,\n",
|
||||
" model: ChatOpenAI,\n",
|
||||
" ) -> None:\n",
|
||||
" self.name = name\n",
|
||||
" self.system_message = system_message\n",
|
||||
" self.model = model\n",
|
||||
" self.prefix = f\"{self.name}: \"\n",
|
||||
" self.reset()\n",
|
||||
" \n",
|
||||
" def reset(self):\n",
|
||||
" self.message_history = [\"Here is the conversation so far.\"]\n",
|
||||
"\n",
|
||||
" def send(self) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Applies the chatmodel to the message history\n",
|
||||
" and returns the message string\n",
|
||||
" \"\"\"\n",
|
||||
" message = self.model(\n",
|
||||
" [\n",
|
||||
" self.system_message,\n",
|
||||
" HumanMessage(content=\"\\n\".join(self.message_history + [self.prefix])),\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" return message.content\n",
|
||||
"\n",
|
||||
" def receive(self, name: str, message: str) -> None:\n",
|
||||
" \"\"\"\n",
|
||||
" Concatenates {message} spoken by {name} into message history\n",
|
||||
" \"\"\"\n",
|
||||
" self.message_history.append(f\"{name}: {message}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class DialogueSimulator:\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" agents: List[DialogueAgent],\n",
|
||||
" selection_function: Callable[[int, List[DialogueAgent]], int],\n",
|
||||
" ) -> None:\n",
|
||||
" self.agents = agents\n",
|
||||
" self._step = 0\n",
|
||||
" self.select_next_speaker = selection_function\n",
|
||||
" \n",
|
||||
" def reset(self):\n",
|
||||
" for agent in self.agents:\n",
|
||||
" agent.reset()\n",
|
||||
"\n",
|
||||
" def inject(self, name: str, message: str):\n",
|
||||
" \"\"\"\n",
|
||||
" Initiates the conversation with a {message} from {name}\n",
|
||||
" \"\"\"\n",
|
||||
" for agent in self.agents:\n",
|
||||
" agent.receive(name, message)\n",
|
||||
"\n",
|
||||
" # increment time\n",
|
||||
" self._step += 1\n",
|
||||
"\n",
|
||||
" def step(self) -> tuple[str, str]:\n",
|
||||
" # 1. choose the next speaker\n",
|
||||
" speaker_idx = self.select_next_speaker(self._step, self.agents)\n",
|
||||
" speaker = self.agents[speaker_idx]\n",
|
||||
"\n",
|
||||
" # 2. next speaker sends message\n",
|
||||
" message = speaker.send()\n",
|
||||
"\n",
|
||||
" # 3. everyone receives message\n",
|
||||
" for receiver in self.agents:\n",
|
||||
" receiver.receive(speaker.name, message)\n",
|
||||
"\n",
|
||||
" # 4. increment time\n",
|
||||
" self._step += 1\n",
|
||||
"\n",
|
||||
" return speaker.name, message"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `DirectorDialogueAgent` class\n",
|
||||
"The `DirectorDialogueAgent` is a privileged agent that chooses which of the other agents to speak next. This agent is responsible for\n",
|
||||
"1. steering the conversation by choosing which agent speaks when\n",
|
||||
"2. terminating the conversation.\n",
|
||||
"\n",
|
||||
"In order to implement such an agent, we need to solve several problems.\n",
|
||||
"\n",
|
||||
"First, to steer the conversation, the `DirectorDialogueAgent` needs to (1) reflect on what has been said, (2) choose the next agent, and (3) prompt the next agent to speak, all in a single message. While it may be possible to prompt an LLM to perform all three steps in the same call, this requires writing custom code to parse the outputted message to extract which next agent is chosen to speak. This is less reliable the LLM can express how it chooses the next agent in different ways.\n",
|
||||
"\n",
|
||||
"What we can do instead is to explicitly break steps (1-3) into three separate LLM calls. First we will ask the `DirectorDialogueAgent` to reflect on the conversation so far and generate a response. Then we prompt the `DirectorDialogueAgent` to output the index of the next agent, which is easily parseable. Lastly, we pass the name of the selected next agent back to `DirectorDialogueAgent` to ask it prompt the next agent to speak. \n",
|
||||
"\n",
|
||||
"Second, simply prompting the `DirectorDialogueAgent` to decide when to terminate the conversation often results in the `DirectorDialogueAgent` terminating the conversation immediately. To fix this problem, we randomly sample a Bernoulli variable to decide whether the conversation should terminate. Depending on the value of this variable, we will inject a custom prompt to tell the `DirectorDialogueAgent` to either continue the conversation or terminate the conversation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class IntegerOutputParser(RegexParser):\n",
|
||||
" def get_format_instructions(self) -> str:\n",
|
||||
" return 'Your response should be an integer delimited by angled brackets, like this: <int>.' \n",
|
||||
"\n",
|
||||
"class DirectorDialogueAgent(DialogueAgent):\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" name,\n",
|
||||
" system_message: SystemMessage,\n",
|
||||
" model: ChatOpenAI,\n",
|
||||
" speakers: List[DialogueAgent],\n",
|
||||
" stopping_probability: float,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__(name, system_message, model)\n",
|
||||
" self.speakers = speakers\n",
|
||||
" self.next_speaker = ''\n",
|
||||
" \n",
|
||||
" self.stop = False\n",
|
||||
" self.stopping_probability = stopping_probability\n",
|
||||
" self.termination_clause = 'Finish the conversation by stating a concluding message and thanking everyone.'\n",
|
||||
" self.continuation_clause = 'Do not end the conversation. Keep the conversation going by adding your own ideas.'\n",
|
||||
" \n",
|
||||
" # 1. have a prompt for generating a response to the previous speaker\n",
|
||||
" self.response_prompt_template = PromptTemplate(\n",
|
||||
" input_variables=[\"message_history\", \"termination_clause\"],\n",
|
||||
" template=f\"\"\"{{message_history}}\n",
|
||||
"\n",
|
||||
"Follow up with an insightful comment.\n",
|
||||
"{{termination_clause}}\n",
|
||||
"{self.prefix}\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" # 2. have a prompt for deciding who to speak next\n",
|
||||
" self.choice_parser = IntegerOutputParser(\n",
|
||||
" regex=r'<(\\d+)>', \n",
|
||||
" output_keys=['choice'], \n",
|
||||
" default_output_key='choice') \n",
|
||||
" self.choose_next_speaker_prompt_template = PromptTemplate(\n",
|
||||
" input_variables=[\"message_history\", \"speaker_names\"],\n",
|
||||
" template=f\"\"\"{{message_history}}\n",
|
||||
"\n",
|
||||
"Given the above conversation, select the next speaker by choosing index next to their name: \n",
|
||||
"{{speaker_names}}\n",
|
||||
"\n",
|
||||
"{self.choice_parser.get_format_instructions()}\n",
|
||||
"\n",
|
||||
"Do nothing else.\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" # 3. have a prompt for prompting the next speaker to speak\n",
|
||||
" self.prompt_next_speaker_prompt_template = PromptTemplate(\n",
|
||||
" input_variables=[\"message_history\", \"next_speaker\"],\n",
|
||||
" template=f\"\"\"{{message_history}}\n",
|
||||
"\n",
|
||||
"The next speaker is {{next_speaker}}. \n",
|
||||
"Prompt the next speaker to speak with an insightful question.\n",
|
||||
"{self.prefix}\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" def _generate_response(self):\n",
|
||||
" # if self.stop = True, then we will inject the prompt with a termination clause\n",
|
||||
" sample = random.uniform(0,1)\n",
|
||||
" self.stop = sample < self.stopping_probability\n",
|
||||
" \n",
|
||||
" print(f'\\tStop? {self.stop}\\n')\n",
|
||||
" \n",
|
||||
" response_prompt = self.response_prompt_template.format(\n",
|
||||
" message_history='\\n'.join(self.message_history),\n",
|
||||
" termination_clause=self.termination_clause if self.stop else ''\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" self.response = self.model(\n",
|
||||
" [\n",
|
||||
" self.system_message,\n",
|
||||
" HumanMessage(content=response_prompt),\n",
|
||||
" ]\n",
|
||||
" ).content\n",
|
||||
" \n",
|
||||
" return self.response\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" @tenacity.retry(stop=tenacity.stop_after_attempt(2),\n",
|
||||
" wait=tenacity.wait_none(), # No waiting time between retries\n",
|
||||
" retry=tenacity.retry_if_exception_type(ValueError),\n",
|
||||
" before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
|
||||
" retry_error_callback=lambda retry_state: 0) # Default value when all retries are exhausted\n",
|
||||
" def _choose_next_speaker(self) -> str: \n",
|
||||
" speaker_names = '\\n'.join([f'{idx}: {name}' for idx, name in enumerate(self.speakers)])\n",
|
||||
" choice_prompt = self.choose_next_speaker_prompt_template.format(\n",
|
||||
" message_history='\\n'.join(self.message_history + [self.prefix] + [self.response]),\n",
|
||||
" speaker_names=speaker_names\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" choice_string = self.model(\n",
|
||||
" [\n",
|
||||
" self.system_message,\n",
|
||||
" HumanMessage(content=choice_prompt),\n",
|
||||
" ]\n",
|
||||
" ).content\n",
|
||||
" choice = int(self.choice_parser.parse(choice_string)['choice'])\n",
|
||||
" \n",
|
||||
" return choice\n",
|
||||
" \n",
|
||||
" def select_next_speaker(self):\n",
|
||||
" return self.chosen_speaker_id\n",
|
||||
" \n",
|
||||
" def send(self) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Applies the chatmodel to the message history\n",
|
||||
" and returns the message string\n",
|
||||
" \"\"\"\n",
|
||||
" # 1. generate and save response to the previous speaker\n",
|
||||
" self.response = self._generate_response()\n",
|
||||
" \n",
|
||||
" if self.stop:\n",
|
||||
" message = self.response \n",
|
||||
" else:\n",
|
||||
" # 2. decide who to speak next\n",
|
||||
" self.chosen_speaker_id = self._choose_next_speaker()\n",
|
||||
" self.next_speaker = self.speakers[self.chosen_speaker_id]\n",
|
||||
" print(f'\\tNext speaker: {self.next_speaker}\\n')\n",
|
||||
"\n",
|
||||
" # 3. prompt the next speaker to speak\n",
|
||||
" next_prompt = self.prompt_next_speaker_prompt_template.format(\n",
|
||||
" message_history=\"\\n\".join(self.message_history + [self.prefix] + [self.response]),\n",
|
||||
" next_speaker=self.next_speaker\n",
|
||||
" )\n",
|
||||
" message = self.model(\n",
|
||||
" [\n",
|
||||
" self.system_message,\n",
|
||||
" HumanMessage(content=next_prompt),\n",
|
||||
" ]\n",
|
||||
" ).content\n",
|
||||
" message = ' '.join([self.response, message])\n",
|
||||
" \n",
|
||||
" return message"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define participants and topic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"topic = \"The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze\"\n",
|
||||
"director_name = \"Jon Stewart\"\n",
|
||||
"agent_summaries = OrderedDict({\n",
|
||||
" \"Jon Stewart\": (\"Host of the Daily Show\", \"New York\"),\n",
|
||||
" \"Samantha Bee\": (\"Hollywood Correspondent\", \"Los Angeles\"), \n",
|
||||
" \"Aasif Mandvi\": (\"CIA Correspondent\", \"Washington D.C.\"),\n",
|
||||
" \"Ronny Chieng\": (\"Average American Correspondent\", \"Cleveland, Ohio\"),\n",
|
||||
"})\n",
|
||||
"word_limit = 50"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Generate system messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_summary_string = '\\n- '.join([''] + [f'{name}: {role}, located in {location}' for name, (role, location) in agent_summaries.items()])\n",
|
||||
"\n",
|
||||
"conversation_description = f\"\"\"This is a Daily Show episode discussing the following topic: {topic}.\n",
|
||||
"\n",
|
||||
"The episode features {agent_summary_string}.\"\"\"\n",
|
||||
"\n",
|
||||
"agent_descriptor_system_message = SystemMessage(\n",
|
||||
" content=\"You can add detail to the description of each person.\")\n",
|
||||
"\n",
|
||||
"def generate_agent_description(agent_name, agent_role, agent_location):\n",
|
||||
" agent_specifier_prompt = [\n",
|
||||
" agent_descriptor_system_message,\n",
|
||||
" HumanMessage(content=\n",
|
||||
" f\"\"\"{conversation_description}\n",
|
||||
" Please reply with a creative description of {agent_name}, who is a {agent_role} in {agent_location}, that emphasizes their particular role and location.\n",
|
||||
" Speak directly to {agent_name} in {word_limit} words or less.\n",
|
||||
" Do not add anything else.\"\"\"\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
" agent_description = ChatOpenAI(temperature=1.0)(agent_specifier_prompt).content\n",
|
||||
" return agent_description\n",
|
||||
"\n",
|
||||
"def generate_agent_header(agent_name, agent_role, agent_location, agent_description):\n",
|
||||
" return f\"\"\"{conversation_description}\n",
|
||||
"\n",
|
||||
"Your name is {agent_name}, your role is {agent_role}, and you are located in {agent_location}.\n",
|
||||
"\n",
|
||||
"Your description is as follows: {agent_description}\n",
|
||||
"\n",
|
||||
"You are discussing the topic: {topic}.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"def generate_agent_system_message(agent_name, agent_header):\n",
|
||||
" return SystemMessage(content=(\n",
|
||||
" f\"\"\"{agent_header}\n",
|
||||
"You will speak in the style of {agent_name}, and exaggerate your personality.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of {agent_name}\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of {agent_name}.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to {word_limit} words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \"\"\"\n",
|
||||
" ))\n",
|
||||
"\n",
|
||||
"agent_descriptions = [generate_agent_description(name, role, location) for name, (role, location) in agent_summaries.items()]\n",
|
||||
"agent_headers = [generate_agent_header(name, role, location, description) for (name, (role, location)), description in zip(agent_summaries.items(), agent_descriptions)]\n",
|
||||
"agent_system_messages = [generate_agent_system_message(name, header) for name, header in zip(agent_summaries, agent_headers)]\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"Jon Stewart Description:\n",
|
||||
"\n",
|
||||
"Jon Stewart, the sharp-tongued and quick-witted host of the Daily Show, holding it down in the hustle and bustle of New York City. Ready to deliver the news with a comedic twist, while keeping it real in the city that never sleeps.\n",
|
||||
"\n",
|
||||
"Header:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Jon Stewart, your role is Host of the Daily Show, and you are located in New York.\n",
|
||||
"\n",
|
||||
"Your description is as follows: Jon Stewart, the sharp-tongued and quick-witted host of the Daily Show, holding it down in the hustle and bustle of New York City. Ready to deliver the news with a comedic twist, while keeping it real in the city that never sleeps.\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"System Message:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Jon Stewart, your role is Host of the Daily Show, and you are located in New York.\n",
|
||||
"\n",
|
||||
"Your description is as follows: Jon Stewart, the sharp-tongued and quick-witted host of the Daily Show, holding it down in the hustle and bustle of New York City. Ready to deliver the news with a comedic twist, while keeping it real in the city that never sleeps.\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Jon Stewart, and exaggerate your personality.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Jon Stewart\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Jon Stewart.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Samantha Bee Description:\n",
|
||||
"\n",
|
||||
"Samantha Bee, your location in Los Angeles as the Hollywood Correspondent gives you a front-row seat to the latest and sometimes outrageous trends in fitness. Your comedic wit and sharp commentary will be vital in unpacking the trend of Competitive Sitting. Let's sit down and discuss.\n",
|
||||
"\n",
|
||||
"Header:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Samantha Bee, your role is Hollywood Correspondent, and you are located in Los Angeles.\n",
|
||||
"\n",
|
||||
"Your description is as follows: Samantha Bee, your location in Los Angeles as the Hollywood Correspondent gives you a front-row seat to the latest and sometimes outrageous trends in fitness. Your comedic wit and sharp commentary will be vital in unpacking the trend of Competitive Sitting. Let's sit down and discuss.\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"System Message:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Samantha Bee, your role is Hollywood Correspondent, and you are located in Los Angeles.\n",
|
||||
"\n",
|
||||
"Your description is as follows: Samantha Bee, your location in Los Angeles as the Hollywood Correspondent gives you a front-row seat to the latest and sometimes outrageous trends in fitness. Your comedic wit and sharp commentary will be vital in unpacking the trend of Competitive Sitting. Let's sit down and discuss.\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Samantha Bee, and exaggerate your personality.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Samantha Bee\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Samantha Bee.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Aasif Mandvi Description:\n",
|
||||
"\n",
|
||||
"Aasif Mandvi, the CIA Correspondent in the heart of Washington D.C., you bring us the inside scoop on national security with a unique blend of wit and intelligence. The nation's capital is lucky to have you, Aasif - keep those secrets safe!\n",
|
||||
"\n",
|
||||
"Header:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Aasif Mandvi, your role is CIA Correspondent, and you are located in Washington D.C..\n",
|
||||
"\n",
|
||||
"Your description is as follows: Aasif Mandvi, the CIA Correspondent in the heart of Washington D.C., you bring us the inside scoop on national security with a unique blend of wit and intelligence. The nation's capital is lucky to have you, Aasif - keep those secrets safe!\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"System Message:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Aasif Mandvi, your role is CIA Correspondent, and you are located in Washington D.C..\n",
|
||||
"\n",
|
||||
"Your description is as follows: Aasif Mandvi, the CIA Correspondent in the heart of Washington D.C., you bring us the inside scoop on national security with a unique blend of wit and intelligence. The nation's capital is lucky to have you, Aasif - keep those secrets safe!\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Aasif Mandvi, and exaggerate your personality.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Aasif Mandvi\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Aasif Mandvi.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Ronny Chieng Description:\n",
|
||||
"\n",
|
||||
"Ronny Chieng, you're the Average American Correspondent in Cleveland, Ohio? Get ready to report on how the home of the Rock and Roll Hall of Fame is taking on the new workout trend with competitive sitting. Let's see if this couch potato craze will take root in the Buckeye State.\n",
|
||||
"\n",
|
||||
"Header:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Ronny Chieng, your role is Average American Correspondent, and you are located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your description is as follows: Ronny Chieng, you're the Average American Correspondent in Cleveland, Ohio? Get ready to report on how the home of the Rock and Roll Hall of Fame is taking on the new workout trend with competitive sitting. Let's see if this couch potato craze will take root in the Buckeye State.\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"System Message:\n",
|
||||
"This is a Daily Show episode discussing the following topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"The episode features \n",
|
||||
"- Jon Stewart: Host of the Daily Show, located in New York\n",
|
||||
"- Samantha Bee: Hollywood Correspondent, located in Los Angeles\n",
|
||||
"- Aasif Mandvi: CIA Correspondent, located in Washington D.C.\n",
|
||||
"- Ronny Chieng: Average American Correspondent, located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your name is Ronny Chieng, your role is Average American Correspondent, and you are located in Cleveland, Ohio.\n",
|
||||
"\n",
|
||||
"Your description is as follows: Ronny Chieng, you're the Average American Correspondent in Cleveland, Ohio? Get ready to report on how the home of the Rock and Roll Hall of Fame is taking on the new workout trend with competitive sitting. Let's see if this couch potato craze will take root in the Buckeye State.\n",
|
||||
"\n",
|
||||
"You are discussing the topic: The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze.\n",
|
||||
"\n",
|
||||
"Your goal is to provide the most informative, creative, and novel perspectives of the topic from the perspective of your role and your location.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Ronny Chieng, and exaggerate your personality.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Ronny Chieng\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Ronny Chieng.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for name, description, header, system_message in zip(agent_summaries, agent_descriptions, agent_headers, agent_system_messages):\n",
|
||||
" print(f'\\n\\n{name} Description:')\n",
|
||||
" print(f'\\n{description}')\n",
|
||||
" print(f'\\nHeader:\\n{header}')\n",
|
||||
" print(f'\\nSystem Message:\\n{system_message.content}')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use an LLM to create an elaborate on debate topic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Original topic:\n",
|
||||
"The New Workout Trend: Competitive Sitting - How Laziness Became the Next Fitness Craze\n",
|
||||
"\n",
|
||||
"Detailed topic:\n",
|
||||
"What is driving people to embrace \"competitive sitting\" as the newest fitness trend despite the immense benefits of regular physical exercise?\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"topic_specifier_prompt = [\n",
|
||||
" SystemMessage(content=\"You can make a task more specific.\"),\n",
|
||||
" HumanMessage(content=\n",
|
||||
" f\"\"\"{conversation_description}\n",
|
||||
" \n",
|
||||
" Please elaborate on the topic. \n",
|
||||
" Frame the topic as a single question to be answered.\n",
|
||||
" Be creative and imaginative.\n",
|
||||
" Please reply with the specified topic in {word_limit} words or less. \n",
|
||||
" Do not add anything else.\"\"\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"specified_topic = ChatOpenAI(temperature=1.0)(topic_specifier_prompt).content\n",
|
||||
"\n",
|
||||
"print(f\"Original topic:\\n{topic}\\n\")\n",
|
||||
"print(f\"Detailed topic:\\n{specified_topic}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define the speaker selection function\n",
|
||||
"Lastly we will define a speaker selection function `select_next_speaker` that takes each agent's bid and selects the agent with the highest bid (with ties broken randomly).\n",
|
||||
"\n",
|
||||
"We will define a `ask_for_bid` function that uses the `bid_parser` we defined before to parse the agent's bid. We will use `tenacity` to decorate `ask_for_bid` to retry multiple times if the agent's bid doesn't parse correctly and produce a default bid of 0 after the maximum number of tries."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def select_next_speaker(step: int, agents: List[DialogueAgent], director: DirectorDialogueAgent) -> int:\n",
|
||||
" \"\"\"\n",
|
||||
" If the step is even, then select the director\n",
|
||||
" Otherwise, the director selects the next speaker.\n",
|
||||
" \"\"\" \n",
|
||||
" # the director speaks on odd steps\n",
|
||||
" if step % 2 == 1:\n",
|
||||
" idx = 0\n",
|
||||
" else:\n",
|
||||
" # here the director chooses the next speaker\n",
|
||||
" idx = director.select_next_speaker() + 1 # +1 because we excluded the director\n",
|
||||
" return idx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Main Loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"director = DirectorDialogueAgent(\n",
|
||||
" name=director_name, \n",
|
||||
" system_message=agent_system_messages[0],\n",
|
||||
" model=ChatOpenAI(temperature=0.2),\n",
|
||||
" speakers=[name for name in agent_summaries if name != director_name],\n",
|
||||
" stopping_probability=0.2\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"agents = [director]\n",
|
||||
"for name, system_message in zip(list(agent_summaries.keys())[1:], agent_system_messages[1:]): \n",
|
||||
" agents.append(DialogueAgent(\n",
|
||||
" name=name,\n",
|
||||
" system_message=system_message,\n",
|
||||
" model=ChatOpenAI(temperature=0.2),\n",
|
||||
" ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(Audience member): What is driving people to embrace \"competitive sitting\" as the newest fitness trend despite the immense benefits of regular physical exercise?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\tStop? False\n",
|
||||
"\n",
|
||||
"\tNext speaker: Samantha Bee\n",
|
||||
"\n",
|
||||
"(Jon Stewart): Well, I think it's safe to say that laziness has officially become the new fitness craze. I mean, who needs to break a sweat when you can just sit your way to victory? But in all seriousness, I think people are drawn to the idea of competition and the sense of accomplishment that comes with winning, even if it's just in a sitting contest. Plus, let's be real, sitting is something we all excel at. Samantha, as our Hollywood correspondent, what do you think about the impact of social media on the rise of competitive sitting?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Samantha Bee): Oh, Jon, you know I love a good social media trend. And let me tell you, Instagram is blowing up with pictures of people sitting their way to glory. It's like the ultimate humble brag. \"Oh, just won my third sitting competition this week, no big deal.\" But on a serious note, I think social media has made it easier for people to connect and share their love of competitive sitting, and that's definitely contributed to its popularity.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\tStop? False\n",
|
||||
"\n",
|
||||
"\tNext speaker: Ronny Chieng\n",
|
||||
"\n",
|
||||
"(Jon Stewart): It's interesting to see how our society's definition of \"fitness\" has evolved. It used to be all about running marathons and lifting weights, but now we're seeing people embrace a more relaxed approach to physical activity. Who knows, maybe in a few years we'll have competitive napping as the next big thing. *leans back in chair* I could definitely get behind that. Ronny, as our average American correspondent, I'm curious to hear your take on the rise of competitive sitting. Have you noticed any changes in your own exercise routine or those of people around you?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Ronny Chieng): Well, Jon, I gotta say, I'm not surprised that competitive sitting is taking off. I mean, have you seen the size of the chairs these days? They're practically begging us to sit in them all day. And as for exercise routines, let's just say I've never been one for the gym. But I can definitely see the appeal of sitting competitions. It's like a sport for the rest of us. Plus, I think it's a great way to bond with friends and family. Who needs a game of catch when you can have a sit-off?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\tStop? False\n",
|
||||
"\n",
|
||||
"\tNext speaker: Aasif Mandvi\n",
|
||||
"\n",
|
||||
"(Jon Stewart): It's interesting to see how our society's definition of \"fitness\" has evolved. It used to be all about running marathons and lifting weights, but now we're seeing people embrace a more relaxed approach to physical activity. Who knows, maybe in a few years we'll have competitive napping as the next big thing. *leans back in chair* I could definitely get behind that. Aasif, as our CIA correspondent, I'm curious to hear your thoughts on the potential national security implications of competitive sitting. Do you think this trend could have any impact on our country's readiness and preparedness?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Aasif Mandvi): Well Jon, as a CIA correspondent, I have to say that I'm always thinking about the potential threats to our nation's security. And while competitive sitting may seem harmless, there could be some unforeseen consequences. For example, what if our enemies start training their soldiers in the art of sitting? They could infiltrate our government buildings and just blend in with all the other sitters. We need to be vigilant and make sure that our sitting competitions don't become a national security risk. *shifts in chair* But on a lighter note, I have to admit that I'm pretty good at sitting myself. Maybe I should start training for the next competition.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\tStop? False\n",
|
||||
"\n",
|
||||
"\tNext speaker: Ronny Chieng\n",
|
||||
"\n",
|
||||
"(Jon Stewart): Well, it's clear that competitive sitting has sparked some interesting discussions and perspectives. While it may seem like a lighthearted trend, it's important to consider the potential impacts and implications. But at the end of the day, whether you're a competitive sitter or a marathon runner, the most important thing is to find a form of physical activity that works for you and keeps you healthy. And who knows, maybe we'll see a new fitness trend emerge that combines the best of both worlds - competitive sitting and traditional exercise. *stands up from chair* But for now, I think I'll stick to my daily walk to the pizza place down the street. Ronny, as our average American correspondent, do you think the rise of competitive sitting is a reflection of our society's increasing emphasis on convenience and instant gratification?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Ronny Chieng): Absolutely, Jon. We live in a world where everything is at our fingertips, and we expect things to be easy and convenient. So it's no surprise that people are drawn to a fitness trend that requires minimal effort and can be done from the comfort of their own homes. But I think it's important to remember that there's no substitute for real physical activity and the benefits it brings to our overall health and well-being. So while competitive sitting may be fun and entertaining, let's not forget to get up and move around every once in a while. *stands up from chair and stretches*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\tStop? False\n",
|
||||
"\n",
|
||||
"\tNext speaker: Samantha Bee\n",
|
||||
"\n",
|
||||
"(Jon Stewart): It's clear that competitive sitting has sparked some interesting discussions and perspectives. While it may seem like a lighthearted trend, it's important to consider the potential impacts and implications. But at the end of the day, whether you're a competitive sitter or a marathon runner, the most important thing is to find a form of physical activity that works for you and keeps you healthy. That's a great point, Ronny. Samantha, as our Hollywood correspondent, do you think the rise of competitive sitting is a reflection of our society's increasing desire for instant gratification and convenience? Or is there something deeper at play here?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Samantha Bee): Oh, Jon, you know I love a good conspiracy theory. And let me tell you, I think there's something more sinister at play here. I mean, think about it - what if the government is behind this whole competitive sitting trend? They want us to be lazy and complacent so we don't question their actions. It's like the ultimate mind control. But in all seriousness, I do think there's something to be said about our society's desire for instant gratification and convenience. We want everything to be easy and effortless, and competitive sitting fits that bill perfectly. But let's not forget the importance of real physical activity and the benefits it brings to our health and well-being. *stands up from chair and does a few stretches*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\tStop? True\n",
|
||||
"\n",
|
||||
"(Jon Stewart): Well, it's clear that competitive sitting has sparked some interesting discussions and perspectives. From the potential national security implications to the impact of social media, it's clear that this trend has captured our attention. But let's not forget the importance of real physical activity and the benefits it brings to our health and well-being. Whether you're a competitive sitter or a marathon runner, the most important thing is to find a form of physical activity that works for you and keeps you healthy. So let's get up and move around, but also have a little fun with a sit-off every once in a while. Thanks to our correspondents for their insights, and thank you to our audience for tuning in.\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"simulator = DialogueSimulator(\n",
|
||||
" agents=agents,\n",
|
||||
" selection_function=functools.partial(select_next_speaker, director=director)\n",
|
||||
")\n",
|
||||
"simulator.reset()\n",
|
||||
"simulator.inject('Audience member', specified_topic)\n",
|
||||
"print(f\"(Audience member): {specified_topic}\")\n",
|
||||
"print('\\n')\n",
|
||||
"\n",
|
||||
"while True:\n",
|
||||
" name, message = simulator.step()\n",
|
||||
" print(f\"({name}): {message}\")\n",
|
||||
" print('\\n')\n",
|
||||
" if director.stop:\n",
|
||||
" break\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,823 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Multi-agent decentralized speaker selection\n",
|
||||
"\n",
|
||||
"This notebook showcases how to implement a multi-agent simulation without a fixed schedule for who speaks when. Instead the agents decide for themselves who speaks. We can implement this by having each agent bid to speak. Whichever agent's bid is the highest gets to speak.\n",
|
||||
"\n",
|
||||
"We will show how to do this in the example below that showcases a fictitious presidential debate."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Import LangChain related modules "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import PromptTemplate\n",
|
||||
"import re\n",
|
||||
"import tenacity\n",
|
||||
"from typing import List, Dict, Callable\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.output_parsers import RegexParser\n",
|
||||
"from langchain.schema import (\n",
|
||||
" AIMessage,\n",
|
||||
" HumanMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" BaseMessage,\n",
|
||||
")\n",
|
||||
"from simulations import DialogueAgent, DialogueSimulator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `DialogueAgent` and `DialogueSimulator` classes\n",
|
||||
"We will use the same `DialogueAgent` and `DialogueSimulator` classes defined in [Multi-Player Dungeons & Dragons](https://python.langchain.com/en/latest/use_cases/agent_simulations/multi_player_dnd.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DialogueAgent:\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" name: str,\n",
|
||||
" system_message: SystemMessage,\n",
|
||||
" model: ChatOpenAI,\n",
|
||||
" ) -> None:\n",
|
||||
" self.name = name\n",
|
||||
" self.system_message = system_message\n",
|
||||
" self.model = model\n",
|
||||
" self.message_history = [\"Here is the conversation so far.\"]\n",
|
||||
" self.prefix = f\"{self.name}:\"\n",
|
||||
"\n",
|
||||
" def send(self) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Applies the chatmodel to the message history\n",
|
||||
" and returns the message string\n",
|
||||
" \"\"\"\n",
|
||||
" message = self.model(\n",
|
||||
" [\n",
|
||||
" self.system_message,\n",
|
||||
" HumanMessage(content=\"\\n\".join(self.message_history + [self.prefix])),\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" return message.content\n",
|
||||
"\n",
|
||||
" def receive(self, name: str, message: str) -> None:\n",
|
||||
" \"\"\"\n",
|
||||
" Concatenates {message} spoken by {name} into message history\n",
|
||||
" \"\"\"\n",
|
||||
" self.message_history.append(f\"{name}: {message}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class DialogueSimulator:\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" agents: List[DialogueAgent],\n",
|
||||
" selection_function: Callable[[int, List[DialogueAgent]], int],\n",
|
||||
" ) -> None:\n",
|
||||
" self.agents = agents\n",
|
||||
" self._step = 0\n",
|
||||
" self.select_next_speaker = selection_function\n",
|
||||
"\n",
|
||||
" def reset(self, name: str, message: str):\n",
|
||||
" \"\"\"\n",
|
||||
" Initiates the conversation with a {message} from {name}\n",
|
||||
" \"\"\"\n",
|
||||
" for agent in self.agents:\n",
|
||||
" agent.receive(name, message)\n",
|
||||
"\n",
|
||||
" # increment time\n",
|
||||
" self._step += 1\n",
|
||||
"\n",
|
||||
" def step(self) -> tuple[str, str]:\n",
|
||||
" # 1. choose the next speaker\n",
|
||||
" speaker_idx = self.select_next_speaker(self._step, self.agents)\n",
|
||||
" speaker = self.agents[speaker_idx]\n",
|
||||
"\n",
|
||||
" # 2. next speaker sends message\n",
|
||||
" message = speaker.send()\n",
|
||||
"\n",
|
||||
" # 3. everyone receives message\n",
|
||||
" for receiver in self.agents:\n",
|
||||
" receiver.receive(speaker.name, message)\n",
|
||||
"\n",
|
||||
" # 4. increment time\n",
|
||||
" self._step += 1\n",
|
||||
"\n",
|
||||
" return speaker.name, message"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `BiddingDialogueAgent` class\n",
|
||||
"We define a subclass of `DialogueAgent` that has a `bid()` method that produces a bid given the message history and the most recent message."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class BiddingDialogueAgent(DialogueAgent):\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" name,\n",
|
||||
" system_message: SystemMessage,\n",
|
||||
" bidding_template: PromptTemplate,\n",
|
||||
" model: ChatOpenAI,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__(name, system_message, model)\n",
|
||||
" self.bidding_template = bidding_template\n",
|
||||
" \n",
|
||||
" def bid(self) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Asks the chat model to output a bid to speak\n",
|
||||
" \"\"\"\n",
|
||||
" prompt = PromptTemplate(\n",
|
||||
" input_variables=['message_history', 'recent_message'],\n",
|
||||
" template = self.bidding_template\n",
|
||||
" ).format(\n",
|
||||
" message_history='\\n'.join(self.message_history),\n",
|
||||
" recent_message=self.message_history[-1])\n",
|
||||
" bid_string = self.model([SystemMessage(content=prompt)]).content\n",
|
||||
" return bid_string\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define participants and debate topic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"character_names = [\"Donald Trump\", \"Kanye West\", \"Elizabeth Warren\"]\n",
|
||||
"topic = \"transcontinental high speed rail\"\n",
|
||||
"word_limit = 50"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Generate system messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"game_description = f\"\"\"Here is the topic for the presidential debate: {topic}.\n",
|
||||
"The presidential candidates are: {', '.join(character_names)}.\"\"\"\n",
|
||||
"\n",
|
||||
"player_descriptor_system_message = SystemMessage(\n",
|
||||
" content=\"You can add detail to the description of each presidential candidate.\")\n",
|
||||
"\n",
|
||||
"def generate_character_description(character_name):\n",
|
||||
" character_specifier_prompt = [\n",
|
||||
" player_descriptor_system_message,\n",
|
||||
" HumanMessage(content=\n",
|
||||
" f\"\"\"{game_description}\n",
|
||||
" Please reply with a creative description of the presidential candidate, {character_name}, in {word_limit} words or less, that emphasizes their personalities. \n",
|
||||
" Speak directly to {character_name}.\n",
|
||||
" Do not add anything else.\"\"\"\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
" character_description = ChatOpenAI(temperature=1.0)(character_specifier_prompt).content\n",
|
||||
" return character_description\n",
|
||||
"\n",
|
||||
"def generate_character_header(character_name, character_description):\n",
|
||||
" return f\"\"\"{game_description}\n",
|
||||
"Your name is {character_name}.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: {character_description}\n",
|
||||
"You are debating the topic: {topic}.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"def generate_character_system_message(character_name, character_header):\n",
|
||||
" return SystemMessage(content=(\n",
|
||||
" f\"\"\"{character_header}\n",
|
||||
"You will speak in the style of {character_name}, and exaggerate their personality.\n",
|
||||
"You will come up with creative ideas related to {topic}.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of {character_name}\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of {character_name}.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to {word_limit} words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \"\"\"\n",
|
||||
" ))\n",
|
||||
"\n",
|
||||
"character_descriptions = [generate_character_description(character_name) for character_name in character_names]\n",
|
||||
"character_headers = [generate_character_header(character_name, character_description) for character_name, character_description in zip(character_names, character_descriptions)]\n",
|
||||
"character_system_messages = [generate_character_system_message(character_name, character_headers) for character_name, character_headers in zip(character_names, character_headers)]\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"Donald Trump Description:\n",
|
||||
"\n",
|
||||
"Donald Trump, you exude confidence and a bold personality. You are known for your unpredictability and your desire for greatness. You often speak your mind without reservation, which can be a strength but also a weakness.\n",
|
||||
"\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Donald Trump.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Donald Trump, you exude confidence and a bold personality. You are known for your unpredictability and your desire for greatness. You often speak your mind without reservation, which can be a strength but also a weakness.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Donald Trump.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Donald Trump, you exude confidence and a bold personality. You are known for your unpredictability and your desire for greatness. You often speak your mind without reservation, which can be a strength but also a weakness.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Donald Trump, and exaggerate their personality.\n",
|
||||
"You will come up with creative ideas related to transcontinental high speed rail.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Donald Trump\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Donald Trump.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Kanye West Description:\n",
|
||||
"\n",
|
||||
"Kanye West, you are a creative visionary who is unafraid to speak your mind. Your innovative approach to art and music has made you one of the most influential figures of our time. You bring a bold and unconventional perspective to this debate that I look forward to hearing.\n",
|
||||
"\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Kanye West.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Kanye West, you are a creative visionary who is unafraid to speak your mind. Your innovative approach to art and music has made you one of the most influential figures of our time. You bring a bold and unconventional perspective to this debate that I look forward to hearing.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Kanye West.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Kanye West, you are a creative visionary who is unafraid to speak your mind. Your innovative approach to art and music has made you one of the most influential figures of our time. You bring a bold and unconventional perspective to this debate that I look forward to hearing.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Kanye West, and exaggerate their personality.\n",
|
||||
"You will come up with creative ideas related to transcontinental high speed rail.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Kanye West\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Kanye West.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Elizabeth Warren Description:\n",
|
||||
"\n",
|
||||
"Elizabeth Warren, you are a fierce advocate for the middle class and a champion of progressive policies. Your tenacity and unwavering dedication to fighting for what you believe in have inspired many. Your policies are guided by a deep sense of empathy and a desire to help those who are most in need.\n",
|
||||
"\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Elizabeth Warren.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Elizabeth Warren, you are a fierce advocate for the middle class and a champion of progressive policies. Your tenacity and unwavering dedication to fighting for what you believe in have inspired many. Your policies are guided by a deep sense of empathy and a desire to help those who are most in need.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Elizabeth Warren.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Elizabeth Warren, you are a fierce advocate for the middle class and a champion of progressive policies. Your tenacity and unwavering dedication to fighting for what you believe in have inspired many. Your policies are guided by a deep sense of empathy and a desire to help those who are most in need.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"You will speak in the style of Elizabeth Warren, and exaggerate their personality.\n",
|
||||
"You will come up with creative ideas related to transcontinental high speed rail.\n",
|
||||
"Do not say the same things over and over again.\n",
|
||||
"Speak in the first person from the perspective of Elizabeth Warren\n",
|
||||
"For describing your own body movements, wrap your description in '*'.\n",
|
||||
"Do not change roles!\n",
|
||||
"Do not speak from the perspective of anyone else.\n",
|
||||
"Speak only from the perspective of Elizabeth Warren.\n",
|
||||
"Stop speaking the moment you finish speaking from your perspective.\n",
|
||||
"Never forget to keep your response to 50 words!\n",
|
||||
"Do not add anything else.\n",
|
||||
" \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for character_name, character_description, character_header, character_system_message in zip(character_names, character_descriptions, character_headers, character_system_messages):\n",
|
||||
" print(f'\\n\\n{character_name} Description:')\n",
|
||||
" print(f'\\n{character_description}')\n",
|
||||
" print(f'\\n{character_header}')\n",
|
||||
" print(f'\\n{character_system_message.content}')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Output parser for bids\n",
|
||||
"We ask the agents to output a bid to speak. But since the agents are LLMs that output strings, we need to \n",
|
||||
"1. define a format they will produce their outputs in\n",
|
||||
"2. parse their outputs\n",
|
||||
"\n",
|
||||
"We can subclass the [RegexParser](https://github.com/hwchase17/langchain/blob/master/langchain/output_parsers/regex.py) to implement our own custom output parser for bids."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class BidOutputParser(RegexParser):\n",
|
||||
" def get_format_instructions(self) -> str:\n",
|
||||
" return 'Your response should be an integer delimited by angled brackets, like this: <int>.' \n",
|
||||
" \n",
|
||||
"bid_parser = BidOutputParser(\n",
|
||||
" regex=r'<(\\d+)>', \n",
|
||||
" output_keys=['bid'],\n",
|
||||
" default_output_key='bid')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Generate bidding system message\n",
|
||||
"This is inspired by the prompt used in [Generative Agents](https://arxiv.org/pdf/2304.03442.pdf) for using an LLM to determine the importance of memories. This will use the formatting instructions from our `BidOutputParser`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def generate_character_bidding_template(character_header):\n",
|
||||
" bidding_template = (\n",
|
||||
" f\"\"\"{character_header}\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{{message_history}}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"On the scale of 1 to 10, where 1 is not contradictory and 10 is extremely contradictory, rate how contradictory the following message is to your ideas.\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{{recent_message}}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"{bid_parser.get_format_instructions()}\n",
|
||||
"Do nothing else.\n",
|
||||
" \"\"\")\n",
|
||||
" return bidding_template\n",
|
||||
"\n",
|
||||
"character_bidding_templates = [generate_character_bidding_template(character_header) for character_header in character_headers]\n",
|
||||
" \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Donald Trump Bidding Template:\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Donald Trump.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Donald Trump, you exude confidence and a bold personality. You are known for your unpredictability and your desire for greatness. You often speak your mind without reservation, which can be a strength but also a weakness.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{message_history}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"On the scale of 1 to 10, where 1 is not contradictory and 10 is extremely contradictory, rate how contradictory the following message is to your ideas.\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{recent_message}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Your response should be an integer delimited by angled brackets, like this: <int>.\n",
|
||||
"Do nothing else.\n",
|
||||
" \n",
|
||||
"Kanye West Bidding Template:\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Kanye West.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Kanye West, you are a creative visionary who is unafraid to speak your mind. Your innovative approach to art and music has made you one of the most influential figures of our time. You bring a bold and unconventional perspective to this debate that I look forward to hearing.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{message_history}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"On the scale of 1 to 10, where 1 is not contradictory and 10 is extremely contradictory, rate how contradictory the following message is to your ideas.\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{recent_message}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Your response should be an integer delimited by angled brackets, like this: <int>.\n",
|
||||
"Do nothing else.\n",
|
||||
" \n",
|
||||
"Elizabeth Warren Bidding Template:\n",
|
||||
"Here is the topic for the presidential debate: transcontinental high speed rail.\n",
|
||||
"The presidential candidates are: Donald Trump, Kanye West, Elizabeth Warren.\n",
|
||||
"Your name is Elizabeth Warren.\n",
|
||||
"You are a presidential candidate.\n",
|
||||
"Your description is as follows: Elizabeth Warren, you are a fierce advocate for the middle class and a champion of progressive policies. Your tenacity and unwavering dedication to fighting for what you believe in have inspired many. Your policies are guided by a deep sense of empathy and a desire to help those who are most in need.\n",
|
||||
"You are debating the topic: transcontinental high speed rail.\n",
|
||||
"Your goal is to be as creative as possible and make the voters think you are the best candidate.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{message_history}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"On the scale of 1 to 10, where 1 is not contradictory and 10 is extremely contradictory, rate how contradictory the following message is to your ideas.\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{recent_message}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Your response should be an integer delimited by angled brackets, like this: <int>.\n",
|
||||
"Do nothing else.\n",
|
||||
" \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for character_name, bidding_template in zip(character_names, character_bidding_templates):\n",
|
||||
" print(f'{character_name} Bidding Template:')\n",
|
||||
" print(bidding_template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use an LLM to create an elaborate on debate topic"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Original topic:\n",
|
||||
"transcontinental high speed rail\n",
|
||||
"\n",
|
||||
"Detailed topic:\n",
|
||||
"Candidates, with the rise of autonomous technologies, we must address the problem of how to integrate them into our proposed transcontinental high speed rail system. Outline your plan on how to safely integrate autonomous vehicles into rail travel, balancing the need for innovation and safety.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"topic_specifier_prompt = [\n",
|
||||
" SystemMessage(content=\"You can make a task more specific.\"),\n",
|
||||
" HumanMessage(content=\n",
|
||||
" f\"\"\"{game_description}\n",
|
||||
" \n",
|
||||
" You are the debate moderator.\n",
|
||||
" Please make the debate topic more specific. \n",
|
||||
" Frame the debate topic as a problem to be solved.\n",
|
||||
" Be creative and imaginative.\n",
|
||||
" Please reply with the specified topic in {word_limit} words or less. \n",
|
||||
" Speak directly to the presidential candidates: {*character_names,}.\n",
|
||||
" Do not add anything else.\"\"\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"specified_topic = ChatOpenAI(temperature=1.0)(topic_specifier_prompt).content\n",
|
||||
"\n",
|
||||
"print(f\"Original topic:\\n{topic}\\n\")\n",
|
||||
"print(f\"Detailed topic:\\n{specified_topic}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Define the speaker selection function\n",
|
||||
"Lastly we will define a speaker selection function `select_next_speaker` that takes each agent's bid and selects the agent with the highest bid (with ties broken randomly).\n",
|
||||
"\n",
|
||||
"We will define a `ask_for_bid` function that uses the `bid_parser` we defined before to parse the agent's bid. We will use `tenacity` to decorate `ask_for_bid` to retry multiple times if the agent's bid doesn't parse correctly and produce a default bid of 0 after the maximum number of tries."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tenacity.retry(stop=tenacity.stop_after_attempt(2),\n",
|
||||
" wait=tenacity.wait_none(), # No waiting time between retries\n",
|
||||
" retry=tenacity.retry_if_exception_type(ValueError),\n",
|
||||
" before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
|
||||
" retry_error_callback=lambda retry_state: 0) # Default value when all retries are exhausted\n",
|
||||
"def ask_for_bid(agent) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Ask for agent bid and parses the bid into the correct format.\n",
|
||||
" \"\"\"\n",
|
||||
" bid_string = agent.bid()\n",
|
||||
" bid = int(bid_parser.parse(bid_string)['bid'])\n",
|
||||
" return bid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int:\n",
|
||||
" bids = []\n",
|
||||
" for agent in agents:\n",
|
||||
" bid = ask_for_bid(agent)\n",
|
||||
" bids.append(bid)\n",
|
||||
" \n",
|
||||
" # randomly select among multiple agents with the same bid\n",
|
||||
" max_value = np.max(bids)\n",
|
||||
" max_indices = np.where(bids == max_value)[0]\n",
|
||||
" idx = np.random.choice(max_indices)\n",
|
||||
" \n",
|
||||
" print('Bids:')\n",
|
||||
" for i, (bid, agent) in enumerate(zip(bids, agents)):\n",
|
||||
" print(f'\\t{agent.name} bid: {bid}')\n",
|
||||
" if i == idx:\n",
|
||||
" selected_name = agent.name\n",
|
||||
" print(f'Selected: {selected_name}')\n",
|
||||
" print('\\n')\n",
|
||||
" return idx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Main Loop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"characters = []\n",
|
||||
"for character_name, character_system_message, bidding_template in zip(character_names, character_system_messages, character_bidding_templates):\n",
|
||||
" characters.append(BiddingDialogueAgent(\n",
|
||||
" name=character_name,\n",
|
||||
" system_message=character_system_message,\n",
|
||||
" model=ChatOpenAI(temperature=0.2),\n",
|
||||
" bidding_template=bidding_template,\n",
|
||||
" ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(Debate Moderator): Candidates, with the rise of autonomous technologies, we must address the problem of how to integrate them into our proposed transcontinental high speed rail system. Outline your plan on how to safely integrate autonomous vehicles into rail travel, balancing the need for innovation and safety.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 8\n",
|
||||
"\tKanye West bid: 2\n",
|
||||
"\tElizabeth Warren bid: 1\n",
|
||||
"Selected: Donald Trump\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Donald Trump): Let me tell you, folks, I have the best plan for integrating autonomous vehicles into our high speed rail system. We're going to use the latest technology, the best technology, to ensure safety and efficiency. And let me tell you, we're going to do it in style. We're going to have luxury autonomous cars that will make you feel like you're in a private jet. It's going to be tremendous, believe me. *gestures with hands*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 8\n",
|
||||
"\tKanye West bid: 7\n",
|
||||
"\tElizabeth Warren bid: 10\n",
|
||||
"Selected: Elizabeth Warren\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Elizabeth Warren): Thank you for the question. As someone who has always fought for the safety and well-being of the American people, I believe that any plan for integrating autonomous vehicles into our high speed rail system must prioritize safety above all else. We need to ensure that these vehicles are thoroughly tested and meet strict safety standards before they are allowed on our rails. Additionally, we must invest in the necessary infrastructure to support these vehicles, such as advanced sensors and communication systems. But we must also ensure that these innovations are accessible to all Americans, not just the wealthy. That's why I propose a public-private partnership to fund and build this system, with a focus on creating good-paying jobs and expanding economic opportunities for all Americans. *smiles confidently*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 8\n",
|
||||
"\tKanye West bid: 2\n",
|
||||
"\tElizabeth Warren bid: 1\n",
|
||||
"Selected: Donald Trump\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Donald Trump): Let me tell you, Elizabeth, safety is important, but we also need to think about innovation and progress. We can't let fear hold us back from achieving greatness. That's why I propose a competition, a race to see which company can create the safest and most efficient autonomous vehicles for our high speed rail system. And let me tell you, the winner will receive a huge government contract and be hailed as a hero. It's going to be tremendous, folks. *points finger*\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 3\n",
|
||||
"\tKanye West bid: 8\n",
|
||||
"\tElizabeth Warren bid: 8\n",
|
||||
"Selected: Kanye West\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Kanye West): Yo, yo, yo, let me jump in here. First of all, I gotta say, I love innovation and progress. But we can't forget about the people, man. We need to make sure that this high speed rail system is accessible to everyone, not just the wealthy. And that means we need to invest in public transportation, not just luxury autonomous cars. We need to make sure that people can get from point A to point B safely and efficiently, without breaking the bank. And let me tell you, we can do it in style too. We can have art installations and live performances on the trains, making it a cultural experience. *starts nodding head*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 7\n",
|
||||
"\tKanye West bid: 2\n",
|
||||
"\tElizabeth Warren bid: 1\n",
|
||||
"Selected: Donald Trump\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Donald Trump): Kanye, I hear what you're saying, but let's not forget about the importance of luxury and comfort. We need to make sure that our high speed rail system is not only accessible, but also enjoyable. That's why I propose that we have different tiers of service, from economy to first class, so that everyone can choose the level of luxury they want. And let me tell you, the first class experience will be something else. We're talking about gourmet meals, personal attendants, and even spa services. It's going to be tremendous, folks. *smirks confidently*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 2\n",
|
||||
"\tKanye West bid: 8\n",
|
||||
"\tElizabeth Warren bid: 10\n",
|
||||
"Selected: Elizabeth Warren\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Elizabeth Warren): I agree with Kanye that we need to prioritize accessibility and affordability for all Americans. But we also need to think about the environmental impact of this system. That's why I propose that we invest in renewable energy sources to power our high speed rail system, such as solar and wind power. We can also use this opportunity to create green jobs and reduce our carbon footprint. And let's not forget about the importance of public input and transparency in this process. We need to engage with communities and listen to their concerns and ideas. *raises hand in emphasis*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 8\n",
|
||||
"\tKanye West bid: 1\n",
|
||||
"\tElizabeth Warren bid: 1\n",
|
||||
"Selected: Donald Trump\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Donald Trump): Let me tell you, Elizabeth, I agree that we need to think about the environment, but we also need to think about the economy. That's why I propose that we use American-made materials and labor to build this high speed rail system. We're going to create jobs and boost our economy, all while creating a world-class transportation system. And let me tell you, it's going to be beautiful. We're going to have the best trains, the best tracks, and the best stations. It's going to be tremendous, folks. *smiles confidently*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 2\n",
|
||||
"\tKanye West bid: 7\n",
|
||||
"\tElizabeth Warren bid: 8\n",
|
||||
"Selected: Elizabeth Warren\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Elizabeth Warren): Thank you, but let's not forget that we also need to prioritize safety and reliability. We can't cut corners or sacrifice quality for the sake of speed or cost. That's why I propose that we have rigorous safety and maintenance standards, with regular inspections and repairs. And we need to invest in training and support for our rail workers, so that they can operate and maintain this system with the highest level of expertise and care. *firmly nods head*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 2\n",
|
||||
"\tKanye West bid: 1\n",
|
||||
"\tElizabeth Warren bid: 1\n",
|
||||
"Selected: Donald Trump\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Donald Trump): Let me tell you, Elizabeth, safety is important, but we also need to think about efficiency and speed. That's why I propose that we use the latest technology, such as artificial intelligence and machine learning, to monitor and maintain our high speed rail system. We can detect and fix any issues before they become a problem, ensuring that our trains run smoothly and on time. And let me tell you, we're going to be the envy of the world with this system. It's going to be tremendous, folks. *smirks confidently*\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Bids:\n",
|
||||
"\tDonald Trump bid: 2\n",
|
||||
"\tKanye West bid: 8\n",
|
||||
"\tElizabeth Warren bid: 8\n",
|
||||
"Selected: Kanye West\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"(Kanye West): Yo, yo, yo, let me jump in here again. I hear what both of y'all are saying, but let's not forget about the culture, man. We need to make sure that this high speed rail system reflects the diversity and creativity of our country. That means we need to have art installations, live performances, and even fashion shows on the trains. We can showcase the best of American culture and inspire people from all over the world. And let me tell you, it's going to be a vibe. *starts swaying to the beat*\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"max_iters = 10\n",
|
||||
"n = 0\n",
|
||||
"\n",
|
||||
"simulator = DialogueSimulator(\n",
|
||||
" agents=characters,\n",
|
||||
" selection_function=select_next_speaker\n",
|
||||
")\n",
|
||||
"simulator.reset('Debate Moderator', specified_topic)\n",
|
||||
"print(f\"(Debate Moderator): {specified_topic}\")\n",
|
||||
"print('\\n')\n",
|
||||
"\n",
|
||||
"while n < max_iters:\n",
|
||||
" name, message = simulator.step()\n",
|
||||
" print(f\"({name}): {message}\")\n",
|
||||
" print('\\n')\n",
|
||||
" n += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -219,7 +219,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.tools import BaseTool, DuckDuckGoSearchRun\n",
|
||||
"from langchain.tools import BaseTool, DuckDuckGoSearchTool\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"\n",
|
||||
"from pydantic import Field\n",
|
||||
@@ -321,7 +321,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip install duckduckgo_search\n",
|
||||
"web_search = DuckDuckGoSearchRun()"
|
||||
"web_search = DuckDuckGoSearchTool()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -618,7 +618,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -5,6 +5,11 @@ from typing import Optional
|
||||
|
||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.cache import BaseCache
|
||||
from langchain.callbacks import (
|
||||
set_default_callback_manager,
|
||||
set_handler,
|
||||
set_tracing_callback_manager,
|
||||
)
|
||||
from langchain.chains import (
|
||||
ConversationChain,
|
||||
LLMBashChain,
|
||||
@@ -30,7 +35,6 @@ from langchain.llms import (
|
||||
Modal,
|
||||
OpenAI,
|
||||
Petals,
|
||||
PipelineAI,
|
||||
SagemakerEndpoint,
|
||||
StochasticAI,
|
||||
Writer,
|
||||
@@ -43,7 +47,7 @@ from langchain.prompts import (
|
||||
PromptTemplate,
|
||||
)
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
from langchain.utilities import ArxivAPIWrapper
|
||||
from langchain.utilities.google_search import GoogleSearchAPIWrapper
|
||||
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
||||
from langchain.utilities.powerbi import PowerBIDataset
|
||||
@@ -62,6 +66,7 @@ del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
verbose: bool = False
|
||||
llm_cache: Optional[BaseCache] = None
|
||||
set_default_callback_manager()
|
||||
|
||||
# For backwards compatibility
|
||||
SerpAPIChain = SerpAPIWrapper
|
||||
@@ -89,7 +94,6 @@ __all__ = [
|
||||
"Modal",
|
||||
"OpenAI",
|
||||
"Petals",
|
||||
"PipelineAI",
|
||||
"StochasticAI",
|
||||
"Writer",
|
||||
"BasePromptTemplate",
|
||||
@@ -113,5 +117,7 @@ __all__ = [
|
||||
"VectorDBQAWithSourcesChain",
|
||||
"QAWithSourcesChain",
|
||||
"PALChain",
|
||||
"set_handler",
|
||||
"set_tracing_callback_manager",
|
||||
"LlamaCpp",
|
||||
]
|
||||
|
||||
@@ -13,13 +13,7 @@ import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.agents.tools import InvalidTool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
@@ -29,6 +23,7 @@ from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
)
|
||||
@@ -51,17 +46,13 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -70,17 +61,13 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -183,17 +170,13 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -202,17 +185,13 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -306,52 +285,38 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
output = self.llm_chain.run(
|
||||
intermediate_steps=intermediate_steps,
|
||||
stop=self.stop,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
|
||||
)
|
||||
return self.output_parser.parse(output)
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
output = await self.llm_chain.arun(
|
||||
intermediate_steps=intermediate_steps,
|
||||
stop=self.stop,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
|
||||
)
|
||||
return self.output_parser.parse(output)
|
||||
|
||||
@@ -403,45 +368,37 @@ class Agent(BaseSingleActionAgent):
|
||||
return thoughts
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
full_output = self.llm_chain.predict(**full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
|
||||
full_output = await self.llm_chain.apredict(**full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
def get_full_inputs(
|
||||
@@ -497,11 +454,7 @@ class Agent(BaseSingleActionAgent):
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
"""Validate that appropriate tools are passed in."""
|
||||
for tool in tools:
|
||||
if not tool.is_single_input:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support multi-input tool {tool.name}."
|
||||
)
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -679,27 +632,24 @@ class AgentExecutor(Chain):
|
||||
|
||||
return True
|
||||
|
||||
def _return(
|
||||
self,
|
||||
output: AgentFinish,
|
||||
intermediate_steps: list,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if run_manager:
|
||||
run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
|
||||
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
|
||||
self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
|
||||
async def _areturn(
|
||||
self,
|
||||
output: AgentFinish,
|
||||
intermediate_steps: list,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
self, output: AgentFinish, intermediate_steps: list
|
||||
) -> Dict[str, Any]:
|
||||
if run_manager:
|
||||
await run_manager.on_agent_finish(
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
final_output = output.return_values
|
||||
@@ -713,18 +663,13 @@ class AgentExecutor(Chain):
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||
"""Take a single step in the thought-action-observation loop.
|
||||
|
||||
Override this to take control of how the agent makes and acts on choices.
|
||||
"""
|
||||
# Call the LLM to see what to do.
|
||||
output = self.agent.plan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
output = self.agent.plan(intermediate_steps, **inputs)
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
@@ -735,8 +680,9 @@ class AgentExecutor(Chain):
|
||||
actions = output
|
||||
result = []
|
||||
for agent_action in actions:
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(agent_action, color="green")
|
||||
self.callback_manager.on_agent_action(
|
||||
agent_action, verbose=self.verbose, color="green"
|
||||
)
|
||||
# Otherwise we lookup the tool
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
@@ -750,7 +696,6 @@ class AgentExecutor(Chain):
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -759,7 +704,6 @@ class AgentExecutor(Chain):
|
||||
agent_action.tool,
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
result.append((agent_action, observation))
|
||||
@@ -771,18 +715,13 @@ class AgentExecutor(Chain):
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||
"""Take a single step in the thought-action-observation loop.
|
||||
|
||||
Override this to take control of how the agent makes and acts on choices.
|
||||
"""
|
||||
# Call the LLM to see what to do.
|
||||
output = await self.agent.aplan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
output = await self.agent.aplan(intermediate_steps, **inputs)
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
@@ -795,8 +734,12 @@ class AgentExecutor(Chain):
|
||||
async def _aperform_agent_action(
|
||||
agent_action: AgentAction,
|
||||
) -> Tuple[AgentAction, str]:
|
||||
if run_manager:
|
||||
await run_manager.on_agent_action(
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_agent_action(
|
||||
agent_action, verbose=self.verbose, color="green"
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_agent_action(
|
||||
agent_action, verbose=self.verbose, color="green"
|
||||
)
|
||||
# Otherwise we lookup the tool
|
||||
@@ -812,7 +755,6 @@ class AgentExecutor(Chain):
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -821,7 +763,6 @@ class AgentExecutor(Chain):
|
||||
agent_action.tool,
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
return agent_action, observation
|
||||
@@ -833,11 +774,7 @@ class AgentExecutor(Chain):
|
||||
|
||||
return list(result)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Run text through and get agent response."""
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
@@ -853,16 +790,10 @@ class AgentExecutor(Chain):
|
||||
# We now enter the agent loop (until it returns something).
|
||||
while self._should_continue(iterations, time_elapsed):
|
||||
next_step_output = self._take_next_step(
|
||||
name_to_tool_map,
|
||||
color_mapping,
|
||||
inputs,
|
||||
intermediate_steps,
|
||||
run_manager=run_manager,
|
||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||
)
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
return self._return(
|
||||
next_step_output, intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
return self._return(next_step_output, intermediate_steps)
|
||||
|
||||
intermediate_steps.extend(next_step_output)
|
||||
if len(next_step_output) == 1:
|
||||
@@ -878,11 +809,7 @@ class AgentExecutor(Chain):
|
||||
)
|
||||
return self._return(output, intermediate_steps)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Run text through and get agent response."""
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
@@ -900,11 +827,7 @@ class AgentExecutor(Chain):
|
||||
try:
|
||||
while self._should_continue(iterations, time_elapsed):
|
||||
next_step_output = await self._atake_next_step(
|
||||
name_to_tool_map,
|
||||
color_mapping,
|
||||
inputs,
|
||||
intermediate_steps,
|
||||
run_manager=run_manager,
|
||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||
)
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
return await self._areturn(next_step_output, intermediate_steps)
|
||||
@@ -922,9 +845,7 @@ class AgentExecutor(Chain):
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(
|
||||
output, intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
return await self._areturn(output, intermediate_steps)
|
||||
except TimeoutError:
|
||||
# stop early when interrupted by the async timeout
|
||||
output = self.agent.return_stopped_response(
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
"""Agent toolkits."""
|
||||
|
||||
from langchain.agents.agent_toolkits.csv.base import create_csv_agent
|
||||
from langchain.agents.agent_toolkits.file_management.toolkit import (
|
||||
FileManagementToolkit,
|
||||
)
|
||||
from langchain.agents.agent_toolkits.jira.toolkit import JiraToolkit
|
||||
from langchain.agents.agent_toolkits.json.base import create_json_agent
|
||||
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||
@@ -11,7 +8,6 @@ from langchain.agents.agent_toolkits.nla.toolkit import NLAToolkit
|
||||
from langchain.agents.agent_toolkits.openapi.base import create_openapi_agent
|
||||
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
|
||||
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
|
||||
from langchain.agents.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
|
||||
from langchain.agents.agent_toolkits.powerbi.base import create_pbi_agent
|
||||
from langchain.agents.agent_toolkits.powerbi.chat_base import create_pbi_chat_agent
|
||||
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
|
||||
@@ -50,6 +46,4 @@ __all__ = [
|
||||
"create_csv_agent",
|
||||
"ZapierToolkit",
|
||||
"JiraToolkit",
|
||||
"FileManagementToolkit",
|
||||
"PlayWrightBrowserToolkit",
|
||||
]
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Local file management toolkit."""
|
||||
|
||||
from langchain.agents.agent_toolkits.file_management.toolkit import (
|
||||
FileManagementToolkit,
|
||||
)
|
||||
|
||||
__all__ = ["FileManagementToolkit"]
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Toolkit for interacting with the local filesystem."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.file_management.copy import CopyFileTool
|
||||
from langchain.tools.file_management.delete import DeleteFileTool
|
||||
from langchain.tools.file_management.file_search import FileSearchTool
|
||||
from langchain.tools.file_management.list_dir import ListDirectoryTool
|
||||
from langchain.tools.file_management.move import MoveFileTool
|
||||
from langchain.tools.file_management.read import ReadFileTool
|
||||
from langchain.tools.file_management.write import WriteFileTool
|
||||
|
||||
_FILE_TOOLS = {
|
||||
tool_cls.__fields__["name"].default: tool_cls
|
||||
for tool_cls in [
|
||||
CopyFileTool,
|
||||
DeleteFileTool,
|
||||
FileSearchTool,
|
||||
MoveFileTool,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
ListDirectoryTool,
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class FileManagementToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with a Local Files."""
|
||||
|
||||
root_dir: Optional[str] = None
|
||||
"""If specified, all file operations are made relative to root_dir."""
|
||||
selected_tools: Optional[List[str]] = None
|
||||
"""If provided, only provide the selected tools. Defaults to all."""
|
||||
|
||||
@root_validator
|
||||
def validate_tools(cls, values: dict) -> dict:
|
||||
selected_tools = values.get("selected_tools") or []
|
||||
for tool_name in selected_tools:
|
||||
if tool_name not in _FILE_TOOLS:
|
||||
raise ValueError(
|
||||
f"File Tool of name {tool_name} not supported."
|
||||
f" Permitted tools: {list(_FILE_TOOLS)}"
|
||||
)
|
||||
return values
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
allowed_tools = self.selected_tools or _FILE_TOOLS.keys()
|
||||
tools: List[BaseTool] = []
|
||||
for tool in allowed_tools:
|
||||
tool_cls = _FILE_TOOLS[tool]
|
||||
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
|
||||
return tools
|
||||
|
||||
|
||||
__all__ = ["FileManagementToolkit"]
|
||||
@@ -28,13 +28,13 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
|
||||
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.memory import ReadOnlySharedMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.requests.tool import BaseRequestsTool
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def create_pandas_dataframe_agent(
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
|
||||
)
|
||||
partial_prompt = prompt.partial(df=str(df.head().to_markdown()))
|
||||
partial_prompt = prompt.partial(df=str(df.head()))
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=partial_prompt,
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""Playwright browser toolkit."""
|
||||
from langchain.agents.agent_toolkits.playwright.toolkit import PlayWrightBrowserToolkit
|
||||
|
||||
__all__ = ["PlayWrightBrowserToolkit"]
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Playwright web browser toolkit."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Type
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.playwright.base import BaseBrowserTool
|
||||
from langchain.tools.playwright.click import ClickTool
|
||||
from langchain.tools.playwright.current_page import CurrentWebPageTool
|
||||
from langchain.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool
|
||||
from langchain.tools.playwright.extract_text import ExtractTextTool
|
||||
from langchain.tools.playwright.get_elements import GetElementsTool
|
||||
from langchain.tools.playwright.navigate import NavigateTool
|
||||
from langchain.tools.playwright.navigate_back import NavigateBackTool
|
||||
from langchain.tools.playwright.utils import create_playwright_browser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from playwright.async_api import Browser as AsyncBrowser
|
||||
|
||||
|
||||
class PlayWrightBrowserToolkit(BaseToolkit):
|
||||
"""Toolkit for web browser tools."""
|
||||
|
||||
browser: AsyncBrowser = Field(default_factory=create_playwright_browser)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def check_args(cls, values: dict) -> dict:
|
||||
"""Check that the arguments are valid."""
|
||||
try:
|
||||
from playwright.async_api import Browser as AsyncBrowser # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"The 'playwright' package is required to use this tool."
|
||||
" Please install it with 'pip install playwright'."
|
||||
)
|
||||
return values
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
tool_classes: List[Type[BaseBrowserTool]] = [
|
||||
ClickTool,
|
||||
NavigateTool,
|
||||
NavigateBackTool,
|
||||
ExtractTextTool,
|
||||
ExtractHyperlinksTool,
|
||||
GetElementsTool,
|
||||
CurrentWebPageTool,
|
||||
]
|
||||
|
||||
return [tool_cls.from_browser(self.browser) for tool_cls in tool_classes]
|
||||
|
||||
@classmethod
|
||||
def from_browser(cls, browser: AsyncBrowser) -> PlayWrightBrowserToolkit:
|
||||
from playwright.async_api import Browser as AsyncBrowser
|
||||
|
||||
cls.update_forward_refs(AsyncBrowser=AsyncBrowser)
|
||||
return cls(browser=browser)
|
||||
@@ -4,10 +4,10 @@ from typing import List, Optional
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.powerbi.prompt import QUESTION_TO_QUERY
|
||||
from langchain.tools.powerbi.tool import (
|
||||
@@ -35,20 +35,24 @@ class PowerBIToolkit(BaseToolkit):
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
if self.callback_manager:
|
||||
chain = LLMChain(
|
||||
llm=self.llm,
|
||||
callback_manager=self.callback_manager,
|
||||
prompt=PromptTemplate(
|
||||
template=QUESTION_TO_QUERY,
|
||||
input_variables=["tool_input", "tables", "schemas", "examples"],
|
||||
chain = (
|
||||
LLMChain(
|
||||
llm=self.llm,
|
||||
callback_manager=self.callback_manager,
|
||||
prompt=PromptTemplate(
|
||||
template=QUESTION_TO_QUERY,
|
||||
input_variables=["tool_input", "tables", "schemas", "examples"],
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
chain = LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=PromptTemplate(
|
||||
template=QUESTION_TO_QUERY,
|
||||
input_variables=["tool_input", "tables", "schemas", "examples"],
|
||||
chain = (
|
||||
LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=PromptTemplate(
|
||||
template=QUESTION_TO_QUERY,
|
||||
input_variables=["tool_input", "tables", "schemas", "examples"],
|
||||
),
|
||||
),
|
||||
)
|
||||
return [
|
||||
@@ -56,8 +60,8 @@ class PowerBIToolkit(BaseToolkit):
|
||||
InfoPowerBITool(powerbi=self.powerbi),
|
||||
ListPowerBITool(powerbi=self.powerbi),
|
||||
InputToQueryTool(
|
||||
llm_chain=chain,
|
||||
powerbi=self.powerbi,
|
||||
llm_chain=chain,
|
||||
examples=self.examples,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import Field
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.chat.output_parser import ChatOutputParser
|
||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@@ -14,7 +13,7 @@ from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.schema import AgentAction, BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.conversational.output_parser import ConvoOutputParser
|
||||
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from langchain.agents.conversational_chat.prompt import (
|
||||
SUFFIX,
|
||||
TEMPLATE_TOOL_RESPONSE,
|
||||
)
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@@ -25,6 +24,7 @@ from langchain.prompts.chat import (
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AIMessage,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
HumanMessage,
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Any, Optional, Sequence
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from langchain.requests import TextRequestsWrapper
|
||||
from langchain.tools.arxiv.tool import ArxivQueryRun
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.bing_search.tool import BingSearchRun
|
||||
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
|
||||
from langchain.tools.ddg_search.tool import DuckDuckGoSearchTool
|
||||
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
|
||||
from langchain.tools.human.tool import HumanInputRun
|
||||
from langchain.tools.python.tool import PythonREPLTool
|
||||
@@ -27,7 +27,6 @@ from langchain.tools.requests.tool import (
|
||||
RequestsPutTool,
|
||||
)
|
||||
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
|
||||
from langchain.tools.shell.tool import ShellTool
|
||||
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
||||
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
|
||||
from langchain.utilities import ArxivAPIWrapper
|
||||
@@ -68,7 +67,11 @@ def _get_tools_requests_delete() -> BaseTool:
|
||||
|
||||
|
||||
def _get_terminal() -> BaseTool:
|
||||
return ShellTool()
|
||||
return Tool(
|
||||
name="Terminal",
|
||||
description="Executes commands in a terminal. Input should be valid commands, and the output will be any output from running that command.",
|
||||
func=BashProcess().run,
|
||||
)
|
||||
|
||||
|
||||
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
|
||||
@@ -103,8 +106,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
|
||||
return Tool(
|
||||
name="Calculator",
|
||||
description="Useful for when you need to answer questions about math.",
|
||||
func=LLMMathChain(llm=llm).run,
|
||||
coroutine=LLMMathChain(llm=llm).arun,
|
||||
func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run,
|
||||
coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun,
|
||||
)
|
||||
|
||||
|
||||
@@ -216,7 +219,7 @@ def _get_bing_search(**kwargs: Any) -> BaseTool:
|
||||
|
||||
|
||||
def _get_ddg_search(**kwargs: Any) -> BaseTool:
|
||||
return DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper(**kwargs))
|
||||
return DuckDuckGoSearchTool(api_wrapper=DuckDuckGoSearchAPIWrapper(**kwargs))
|
||||
|
||||
|
||||
def _get_human_tool(**kwargs: Any) -> BaseTool:
|
||||
|
||||
@@ -10,10 +10,10 @@ from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.output_parser import MRKLOutputParser
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@@ -122,7 +122,6 @@ class ZeroShotAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
for tool in tools:
|
||||
if tool.description is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -37,7 +37,6 @@ class ReActDocstoreAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 2:
|
||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
@@ -120,7 +119,6 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
|
||||
@@ -36,7 +36,6 @@ class SelfAskWithSearchAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Interface for tools."""
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, validate_arguments, validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
from langchain.tools.base import (
|
||||
BaseTool,
|
||||
create_schema_from_function,
|
||||
get_filtered_args,
|
||||
)
|
||||
from langchain.tools.base import BaseTool, StructuredTool
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
@@ -30,68 +30,25 @@ class Tool(BaseTool):
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""The tool's input arguments."""
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
# For backwards compatibility, if the function signature is ambiguous,
|
||||
# assume it takes a single string input.
|
||||
return {"tool_input": {"type": "string"}}
|
||||
else:
|
||||
inferred_model = validate_arguments(self.func).model # type: ignore
|
||||
return get_filtered_args(inferred_model, self.func)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
args, kwargs = super()._to_args_and_kwargs(tool_input)
|
||||
# For backwards compatibility. The tool must be run with a single input
|
||||
all_args = list(args) + list(kwargs.values())
|
||||
if len(all_args) != 1:
|
||||
raise ValueError(
|
||||
f"Too many arguments to single-input tool {self.name}."
|
||||
f" Args: {all_args}"
|
||||
)
|
||||
return tuple(all_args), {}
|
||||
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""Use the tool."""
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
return (
|
||||
await self.coroutine(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
raise NotImplementedError("Tool does not support async")
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
def __init__(
|
||||
self, name: str, func: Callable, description: str, **kwargs: Any
|
||||
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize tool."""
|
||||
super(Tool, self).__init__(
|
||||
@@ -105,17 +62,11 @@ class InvalidTool(BaseTool):
|
||||
name = "invalid_tool"
|
||||
description = "Called when tool name is invalid."
|
||||
|
||||
def _run(
|
||||
self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
) -> str:
|
||||
def _run(self, tool_name: str) -> str:
|
||||
"""Use the tool."""
|
||||
return f"{tool_name} is not a valid tool, try another one."
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_name: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
async def _arun(self, tool_name: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
return f"{tool_name} is not a valid tool, try another one."
|
||||
|
||||
@@ -156,24 +107,22 @@ def tool(
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(func: Callable) -> BaseTool:
|
||||
if infer_schema or args_schema is not None:
|
||||
return StructuredTool.from_function(
|
||||
func,
|
||||
name=tool_name,
|
||||
return_direct=return_direct,
|
||||
args_schema=args_schema,
|
||||
infer_schema=infer_schema,
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
assert func.__doc__ is not None, "Function must have a docstring"
|
||||
return Tool(
|
||||
def _make_tool(func: Callable) -> Tool:
|
||||
assert func.__doc__, "Function must have a docstring"
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
|
||||
tool_ = Tool(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
description=f"{tool_name} tool",
|
||||
args_schema=_args_schema,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
)
|
||||
return tool_
|
||||
|
||||
return _make_tool
|
||||
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Base class for all language models."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
|
||||
|
||||
|
||||
class BaseLanguageModel(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text."""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate get_num_tokens. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-3 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
# tokenize the text using the GPT-3 tokenizer
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
||||
# calculate the number of tokens in the tokenized text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the message."""
|
||||
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Beta Feature: base interface for cache."""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||
from sqlalchemy.engine.base import Engine
|
||||
@@ -28,10 +28,6 @@ class BaseCache(ABC):
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache that can take additional keyword arguments."""
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""Cache that stores things in memory."""
|
||||
@@ -48,10 +44,6 @@ class InMemoryCache(BaseCache):
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self._cache[(prompt, llm_string)] = return_val
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
self._cache = {}
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -69,7 +61,7 @@ class FullLLMCache(Base): # type: ignore
|
||||
class SQLAlchemyCache(BaseCache):
|
||||
"""Cache that uses SQAlchemy as a backend."""
|
||||
|
||||
def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
|
||||
def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
|
||||
"""Initialize by creating all tables."""
|
||||
self.engine = engine
|
||||
self.cache_schema = cache_schema
|
||||
@@ -84,26 +76,20 @@ class SQLAlchemyCache(BaseCache):
|
||||
.order_by(self.cache_schema.idx)
|
||||
)
|
||||
with Session(self.engine) as session:
|
||||
rows = session.execute(stmt).fetchall()
|
||||
if rows:
|
||||
return [Generation(text=row[0]) for row in rows]
|
||||
generations = [Generation(text=row[0]) for row in session.execute(stmt)]
|
||||
if len(generations) > 0:
|
||||
return generations
|
||||
return None
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update based on prompt and llm_string."""
|
||||
items = [
|
||||
self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
|
||||
for i, gen in enumerate(return_val)
|
||||
]
|
||||
with Session(self.engine) as session, session.begin():
|
||||
for item in items:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
for i, generation in enumerate(return_val):
|
||||
item = self.cache_schema(
|
||||
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
||||
)
|
||||
with Session(self.engine) as session, session.begin():
|
||||
session.merge(item)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
with Session(self.engine) as session:
|
||||
session.execute(self.cache_schema.delete())
|
||||
|
||||
|
||||
class SQLiteCache(SQLAlchemyCache):
|
||||
"""Cache that uses SQLite as a backend."""
|
||||
@@ -153,26 +139,19 @@ class RedisCache(BaseCache):
|
||||
for i, generation in enumerate(return_val):
|
||||
self.redis.set(self._key(prompt, llm_string, i), generation.text)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
|
||||
asynchronous = kwargs.get("asynchronous", False)
|
||||
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
|
||||
|
||||
|
||||
class GPTCache(BaseCache):
|
||||
"""Cache that uses GPTCache as a backend."""
|
||||
|
||||
def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
|
||||
"""Initialize by passing in init function (default: `None`).
|
||||
def __init__(self, init_func: Callable[[Any], None]):
|
||||
"""Initialize by passing in the `init` GPTCache func
|
||||
|
||||
Args:
|
||||
init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
|
||||
(default: `None`)
|
||||
init_func (Callable[[Any], None]): init `GPTCache` function
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Initialize GPTCache with a custom init function
|
||||
import gptcache
|
||||
from gptcache.processor.pre import get_prompt
|
||||
from gptcache.manager.factory import get_data_manager
|
||||
@@ -201,8 +180,7 @@ class GPTCache(BaseCache):
|
||||
"Could not import gptcache python package. "
|
||||
"Please install it with `pip install gptcache`."
|
||||
)
|
||||
|
||||
self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
|
||||
self.init_gptcache_func: Callable[[Any], None] = init_func
|
||||
self.gptcache_dict: Dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
@@ -227,19 +205,11 @@ class GPTCache(BaseCache):
|
||||
|
||||
When the corresponding llm model cache does not exist, it will be created."""
|
||||
from gptcache import Cache
|
||||
from gptcache.manager.factory import get_data_manager
|
||||
from gptcache.processor.pre import get_prompt
|
||||
|
||||
_gptcache = self.gptcache_dict.get(llm_string, None)
|
||||
if _gptcache is None:
|
||||
_gptcache = Cache()
|
||||
if self.init_gptcache_func is not None:
|
||||
self.init_gptcache_func(_gptcache)
|
||||
else:
|
||||
_gptcache.init(
|
||||
pre_embedding_func=get_prompt,
|
||||
data_manager=get_data_manager(data_path=llm_string),
|
||||
)
|
||||
self.init_gptcache_func(_gptcache)
|
||||
self.gptcache_dict[llm_string] = _gptcache
|
||||
return _gptcache
|
||||
|
||||
@@ -250,7 +220,7 @@ class GPTCache(BaseCache):
|
||||
"""
|
||||
from gptcache.adapter.adapter import adapt
|
||||
|
||||
_gptcache = self.gptcache_dict.get(llm_string, None)
|
||||
_gptcache = self.gptcache_dict.get(llm_string)
|
||||
if _gptcache is None:
|
||||
return None
|
||||
res = adapt(
|
||||
@@ -264,10 +234,7 @@ class GPTCache(BaseCache):
|
||||
|
||||
@staticmethod
|
||||
def _update_cache_callback(
|
||||
llm_data: RETURN_VAL_TYPE,
|
||||
update_cache_func: Callable[[Any], None],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None]
|
||||
) -> None:
|
||||
"""Save the `llm_data` to cache storage"""
|
||||
handled_data = json.dumps([generation.dict() for generation in llm_data])
|
||||
@@ -293,13 +260,3 @@ class GPTCache(BaseCache):
|
||||
cache_skip=True,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
from gptcache import Cache
|
||||
|
||||
for gptcache_instance in self.gptcache_dict.values():
|
||||
gptcache_instance = cast(Cache, gptcache_instance)
|
||||
gptcache_instance.flush()
|
||||
|
||||
self.gptcache_dict.clear()
|
||||
|
||||
@@ -1,27 +1,80 @@
|
||||
"""Callback handlers that allow listening to events in LangChain."""
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from typing import Generator, Optional
|
||||
|
||||
from langchain.callbacks.aim_callback import AimCallbackHandler
|
||||
from langchain.callbacks.base import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
|
||||
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
get_openai_callback,
|
||||
tracing_enabled,
|
||||
)
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.callbacks.tracers import LangChainTracer
|
||||
from langchain.callbacks.tracers import SharedLangChainTracer
|
||||
from langchain.callbacks.wandb_callback import WandbCallbackHandler
|
||||
|
||||
|
||||
def get_callback_manager() -> BaseCallbackManager:
|
||||
"""Return the shared callback manager."""
|
||||
return SharedCallbackManager()
|
||||
|
||||
|
||||
def set_handler(handler: BaseCallbackHandler) -> None:
|
||||
"""Set handler."""
|
||||
callback = get_callback_manager()
|
||||
callback.set_handler(handler)
|
||||
|
||||
|
||||
def set_default_callback_manager() -> None:
|
||||
"""Set default callback manager."""
|
||||
default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
|
||||
if default_handler == "stdout":
|
||||
set_handler(StdOutCallbackHandler())
|
||||
elif default_handler == "langchain":
|
||||
session = os.environ.get("LANGCHAIN_SESSION")
|
||||
set_tracing_callback_manager(session)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LANGCHAIN_HANDLER should be one of `stdout` "
|
||||
f"or `langchain`, got {default_handler}"
|
||||
)
|
||||
|
||||
|
||||
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
|
||||
"""Set tracing callback manager."""
|
||||
handler = SharedLangChainTracer()
|
||||
callback = get_callback_manager()
|
||||
callback.set_handlers([handler, StdOutCallbackHandler()])
|
||||
if session_name is None:
|
||||
handler.load_default_session()
|
||||
else:
|
||||
try:
|
||||
handler.load_session(session_name)
|
||||
except Exception:
|
||||
raise ValueError(f"session {session_name} not found")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
||||
"""Get OpenAI callback handler in a context manager."""
|
||||
handler = OpenAICallbackHandler()
|
||||
manager = get_callback_manager()
|
||||
manager.add_handler(handler)
|
||||
yield handler
|
||||
manager.remove_handler(handler)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CallbackManager",
|
||||
"AsyncCallbackManager",
|
||||
"OpenAICallbackHandler",
|
||||
"SharedCallbackManager",
|
||||
"StdOutCallbackHandler",
|
||||
"AimCallbackHandler",
|
||||
"WandbCallbackHandler",
|
||||
@@ -29,5 +82,8 @@ __all__ = [
|
||||
"CometCallbackHandler",
|
||||
"AsyncIteratorCallbackHandler",
|
||||
"get_openai_callback",
|
||||
"tracing_enabled",
|
||||
"set_tracing_callback_manager",
|
||||
"set_default_callback_manager",
|
||||
"set_handler",
|
||||
"get_callback_manager",
|
||||
]
|
||||
|
||||
@@ -1,174 +1,20 @@
|
||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class LLMManagerMixin:
|
||||
"""Mixin for LLM callbacks."""
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class ToolManagerMixin:
|
||||
"""Mixin for tool callbacks."""
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
|
||||
class CallbackManagerMixin:
|
||||
"""Mixin for callback manager."""
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
|
||||
class RunManagerMixin:
|
||||
"""Mixin for run manager."""
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
RunManagerMixin,
|
||||
):
|
||||
class BaseCallbackHandler(ABC):
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
@@ -184,197 +30,480 @@ class BaseCallbackHandler(
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@abstractmethod
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: List[BaseCallbackHandler] = (
|
||||
inheritable_handlers or []
|
||||
)
|
||||
self.parent_run_id: Optional[str] = parent_run_id
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Whether the callback manager is async."""
|
||||
return False
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
@abstractmethod
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
|
||||
@abstractmethod
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
self.set_handlers([handler])
|
||||
|
||||
@abstractmethod
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_new_token(
|
||||
self, token: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_new_token(token, **kwargs)
|
||||
|
||||
def on_llm_end(
|
||||
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_end(response)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(
|
||||
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_chain_end(outputs)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_start(serialized, input_str, **kwargs)
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_agent_action(action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_error(error)
|
||||
|
||||
def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
for handler in self.handlers:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_agent_finish(finish, **kwargs)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
if inherit:
|
||||
self.inheritable_handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
self.inheritable_handlers.remove(handler)
|
||||
|
||||
def set_handlers(
|
||||
self, handlers: List[BaseCallbackHandler], inherit: bool = True
|
||||
) -> None:
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = []
|
||||
self.inheritable_handlers = []
|
||||
for handler in handlers:
|
||||
self.add_handler(handler, inherit=inherit)
|
||||
self.handlers = handlers
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
self.set_handlers([handler], inherit=inherit)
|
||||
|
||||
def __copy__(self) -> "BaseCallbackManager":
|
||||
return self.__class__(
|
||||
self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
async def _handle_event_for_handler(
|
||||
handler: BaseCallbackHandler,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
verbose: bool,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
if verbose or handler.always_verbose:
|
||||
event = getattr(handler, event_name)
|
||||
if asyncio.iscoroutinefunction(event):
|
||||
await event(*args, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(event, *args, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Async callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the handler is async."""
|
||||
return True
|
||||
|
||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
async def _handle_event(
|
||||
self,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
verbose: bool,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Generic event handler for AsyncCallbackManager."""
|
||||
await asyncio.gather(
|
||||
*(
|
||||
_handle_event_for_handler(
|
||||
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
|
||||
)
|
||||
for handler in self.handlers
|
||||
)
|
||||
)
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "BaseCallbackManager":
|
||||
return self.__class__(
|
||||
[copy.deepcopy(handler, memo) for handler in self.handlers],
|
||||
[copy.deepcopy(handler, memo) for handler in self.inheritable_handlers],
|
||||
self.parent_run_id,
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
await self._handle_event(
|
||||
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
|
||||
)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self, token: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
await self._handle_event(
|
||||
"on_llm_new_token", "ignore_llm", verbose, token, **kwargs
|
||||
)
|
||||
|
||||
async def on_llm_end(
|
||||
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
await self._handle_event(
|
||||
"on_llm_end", "ignore_llm", verbose, response, **kwargs
|
||||
)
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
await self._handle_event(
|
||||
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
|
||||
)
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
await self._handle_event(
|
||||
"on_chain_end", "ignore_chain", verbose, outputs, **kwargs
|
||||
)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
await self._handle_event(
|
||||
"on_chain_error", "ignore_chain", verbose, error, **kwargs
|
||||
)
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
await self._handle_event(
|
||||
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self, output: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
await self._handle_event(
|
||||
"on_tool_end", "ignore_agent", verbose, output, **kwargs
|
||||
)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
await self._handle_event(
|
||||
"on_tool_error", "ignore_agent", verbose, error, **kwargs
|
||||
)
|
||||
|
||||
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run when text is printed."""
|
||||
await self._handle_event("on_text", None, verbose, text, **kwargs)
|
||||
|
||||
async def on_agent_action(
|
||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent action."""
|
||||
await self._handle_event(
|
||||
"on_agent_action", "ignore_agent", verbose, action, **kwargs
|
||||
)
|
||||
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when agent finishes."""
|
||||
await self._handle_event(
|
||||
"on_agent_finish", "ignore_agent", verbose, finish, **kwargs
|
||||
)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
||||
@@ -1,731 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, TypeVar, Union
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
ChainManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RunManagerMixin,
|
||||
ToolManagerMixin,
|
||||
)
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.base import TracerSession
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
|
||||
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
|
||||
"openai_callback", default=None
|
||||
)
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
|
||||
"tracing_callback", default=None
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
||||
"""Get OpenAI callback handler in a context manager."""
|
||||
cb = OpenAICallbackHandler()
|
||||
openai_callback_var.set(cb)
|
||||
yield cb
|
||||
openai_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_enabled(
|
||||
session_name: str = "default",
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
"""Get OpenAI callback handler in a context manager."""
|
||||
cb = LangChainTracer()
|
||||
session = cb.load_session(session_name)
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
tracing_callback_var.set(None)
|
||||
|
||||
|
||||
def _handle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
for handler in handlers:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(
|
||||
handler, ignore_condition_name
|
||||
):
|
||||
getattr(handler, event_name)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# TODO: switch this to use logging
|
||||
print(f"Error in {event_name} callback: {e}")
|
||||
|
||||
|
||||
async def _ahandle_event_for_handler(
|
||||
handler: BaseCallbackHandler,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
event = getattr(handler, event_name)
|
||||
if asyncio.iscoroutinefunction(event):
|
||||
await event(*args, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(event, *args, **kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: switch this to use logging
|
||||
print(f"Error in {event_name} callback: {e}")
|
||||
|
||||
|
||||
async def _ahandle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Generic event handler for AsyncCallbackManager."""
|
||||
await asyncio.gather(
|
||||
*(
|
||||
_ahandle_event_for_handler(
|
||||
handler, event_name, ignore_condition_name, *args, **kwargs
|
||||
)
|
||||
for handler in handlers
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
BRM = TypeVar("BRM", bound="BaseRunManager")
|
||||
|
||||
|
||||
class BaseRunManager(RunManagerMixin):
|
||||
"""Base class for run manager (a bound callback manager)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler],
|
||||
parent_run_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize run manager."""
|
||||
self.run_id = run_id
|
||||
self.handlers = handlers
|
||||
self.inheritable_handlers = inheritable_handlers
|
||||
self.parent_run_id = parent_run_id
|
||||
|
||||
@classmethod
|
||||
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
||||
"""Return a manager that doesn't perform any operations."""
|
||||
return cls("", [], [])
|
||||
|
||||
|
||||
class RunManager(BaseRunManager):
|
||||
"""Sync Run Manager."""
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when text is received."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_text",
|
||||
None,
|
||||
text,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncRunManager(BaseRunManager):
|
||||
"""Async Run Manager."""
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when text is received."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_text",
|
||||
None,
|
||||
text,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
"""Callback manager for LLM run."""
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_new_token",
|
||||
"ignore_llm",
|
||||
token=token,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_end",
|
||||
"ignore_llm",
|
||||
response,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_error",
|
||||
"ignore_llm",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Async callback manager for LLM run."""
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_new_token",
|
||||
"ignore_llm",
|
||||
token,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_end",
|
||||
"ignore_llm",
|
||||
response,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_error",
|
||||
"ignore_llm",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
||||
"""Callback manager for chain run."""
|
||||
|
||||
def get_child(self) -> CallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chain_end",
|
||||
"ignore_chain",
|
||||
outputs,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chain_error",
|
||||
"ignore_chain",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run when agent action is received."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_agent_action",
|
||||
"ignore_agent",
|
||||
action,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run when agent finish is received."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_agent_finish",
|
||||
"ignore_agent",
|
||||
finish,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
||||
"""Async callback manager for chain run."""
|
||||
|
||||
def get_child(self) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_end",
|
||||
"ignore_chain",
|
||||
outputs,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_error",
|
||||
"ignore_chain",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run when agent action is received."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_agent_action",
|
||||
"ignore_agent",
|
||||
action,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run when agent finish is received."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_agent_finish",
|
||||
"ignore_agent",
|
||||
finish,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
||||
"""Callback manager for tool run."""
|
||||
|
||||
def get_child(self) -> CallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
"ignore_agent",
|
||||
output,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_tool_error",
|
||||
"ignore_agent",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||
"""Async callback manager for tool run."""
|
||||
|
||||
def get_child(self) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
"ignore_agent",
|
||||
output,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_error",
|
||||
"ignore_agent",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForLLMRun:
|
||||
"""Run when LLM starts running."""
|
||||
if run_id is None:
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
serialized,
|
||||
prompts,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CallbackManagerForLLMRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForChainRun:
|
||||
"""Run when chain starts running."""
|
||||
if run_id is None:
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chain_start",
|
||||
"ignore_chain",
|
||||
serialized,
|
||||
inputs,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CallbackManagerForChainRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
run_id: Optional[str] = None,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForToolRun:
|
||||
"""Run when tool starts running."""
|
||||
if run_id is None:
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_tool_start",
|
||||
"ignore_agent",
|
||||
serialized,
|
||||
input_str,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CallbackManagerForToolRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
) -> CallbackManager:
|
||||
"""Configure the callback manager."""
|
||||
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
|
||||
|
||||
|
||||
class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Async callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the handler is async."""
|
||||
return True
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForLLMRun:
|
||||
"""Run when LLM starts running."""
|
||||
if run_id is None:
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
serialized,
|
||||
prompts,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return AsyncCallbackManagerForLLMRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForChainRun:
|
||||
"""Run when chain starts running."""
|
||||
if run_id is None:
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_start",
|
||||
"ignore_chain",
|
||||
serialized,
|
||||
inputs,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return AsyncCallbackManagerForChainRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
run_id: Optional[str] = None,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForToolRun:
|
||||
"""Run when tool starts running."""
|
||||
if run_id is None:
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_start",
|
||||
"ignore_agent",
|
||||
serialized,
|
||||
input_str,
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return AsyncCallbackManagerForToolRun(
|
||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Configure the callback manager."""
|
||||
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
|
||||
|
||||
|
||||
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
||||
|
||||
|
||||
def _configure(
|
||||
callback_manager_cls: Type[T],
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
) -> T:
|
||||
"""Configure the callback manager."""
|
||||
callback_manager = callback_manager_cls([])
|
||||
if inheritable_callbacks or local_callbacks:
|
||||
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
||||
inheritable_callbacks_ = inheritable_callbacks or []
|
||||
callback_manager = callback_manager_cls(
|
||||
handlers=inheritable_callbacks_,
|
||||
inheritable_handlers=inheritable_callbacks_,
|
||||
)
|
||||
else:
|
||||
callback_manager = callback_manager_cls(
|
||||
handlers=inheritable_callbacks.handlers,
|
||||
inheritable_handlers=inheritable_callbacks.inheritable_handlers,
|
||||
parent_run_id=inheritable_callbacks.parent_run_id,
|
||||
)
|
||||
callback_manager = copy.deepcopy(callback_manager)
|
||||
local_handlers_ = (
|
||||
local_callbacks
|
||||
if isinstance(local_callbacks, list)
|
||||
else (local_callbacks.handlers if local_callbacks else [])
|
||||
)
|
||||
for handler in local_handlers_:
|
||||
callback_manager.add_handler(copy.deepcopy(handler), False)
|
||||
|
||||
tracer = tracing_callback_var.get()
|
||||
open_ai = openai_callback_var.get()
|
||||
tracing_enabled_ = (
|
||||
os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None
|
||||
)
|
||||
if verbose or tracing_enabled_ or open_ai is not None:
|
||||
if verbose and not any(
|
||||
isinstance(handler, StdOutCallbackHandler)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(StdOutCallbackHandler(), False)
|
||||
|
||||
if tracing_enabled_ and not any(
|
||||
isinstance(handler, LangChainTracer)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
if tracer:
|
||||
callback_manager.add_handler(copy.deepcopy(tracer), True)
|
||||
else:
|
||||
handler = LangChainTracer()
|
||||
handler.load_default_session()
|
||||
callback_manager.add_handler(handler, True)
|
||||
if open_ai is not None and not any(
|
||||
isinstance(handler, OpenAICallbackHandler)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(open_ai, True)
|
||||
|
||||
return callback_manager
|
||||
127
langchain/callbacks/shared.py
Normal file
127
langchain/callbacks/shared.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""A shared CallbackManager."""
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class Singleton:
|
||||
"""A thread-safe singleton class that can be inherited from."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> Any:
|
||||
"""Create a new shared instance of the class."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
# Another thread could have created the instance
|
||||
# before we acquired the lock. So check that the
|
||||
# instance is still nonexistent.
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
"""A thread-safe singleton CallbackManager."""
|
||||
|
||||
_callback_manager: CallbackManager = CallbackManager(handlers=[])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_end(response, **kwargs)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_new_token(token, **kwargs)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_error(error, **kwargs)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_end(outputs, **kwargs)
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_action(action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_error(error, **kwargs)
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_finish(finish, **kwargs)
|
||||
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a callback to the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.add_handler(callback)
|
||||
|
||||
def remove_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Remove a callback from the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.remove_handler(callback)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.handlers = handlers
|
||||
@@ -1,5 +1,12 @@
|
||||
"""Tracers that record execution of LangChain runs."""
|
||||
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.base import SharedTracer, Tracer
|
||||
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
|
||||
|
||||
__all__ = ["LangChainTracer"]
|
||||
|
||||
class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
|
||||
"""Shared tracer that records LangChain execution to LangChain endpoint."""
|
||||
|
||||
|
||||
class LangChainTracer(Tracer, BaseLangChainTracer):
|
||||
"""Tracer that records LangChain execution to LangChain endpoint."""
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Base interfaces for tracing runs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.shared import Singleton
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
@@ -14,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class TracerException(Exception):
|
||||
@@ -24,25 +26,13 @@ class TracerException(Exception):
|
||||
class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Base interface for tracers."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
||||
self.session: Optional[TracerSession] = None
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
if isinstance(child_run, LLMRun):
|
||||
parent_run.child_llm_runs.append(child_run)
|
||||
elif isinstance(child_run, ChainRun):
|
||||
parent_run.child_chain_runs.append(child_run)
|
||||
elif isinstance(child_run, ToolRun):
|
||||
parent_run.child_tool_runs.append(child_run)
|
||||
else:
|
||||
raise TracerException(f"Invalid run type: {type(child_run)}")
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
@@ -52,11 +42,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
|
||||
@abstractmethod
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
|
||||
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
|
||||
"""NOT thread safe, do not call this method from multiple threads."""
|
||||
session_create = TracerSessionCreate(name=name, extra=kwargs)
|
||||
session = self._persist_session(session_create)
|
||||
self.session = session
|
||||
self._session = session
|
||||
return session
|
||||
|
||||
@abstractmethod
|
||||
@@ -67,232 +61,283 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
|
||||
@_execution_order.setter
|
||||
@abstractmethod
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
|
||||
@_session.setter
|
||||
@abstractmethod
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
|
||||
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Start a trace for a run."""
|
||||
if run.parent_uuid:
|
||||
parent_run = self.run_map[run.parent_uuid]
|
||||
if parent_run:
|
||||
if isinstance(parent_run, LLMRun):
|
||||
raise TracerException(
|
||||
"Cannot add child run to an LLM run. "
|
||||
"LLM runs are not allowed to have children."
|
||||
)
|
||||
self._add_child_run(parent_run, run)
|
||||
else:
|
||||
self._execution_order += 1
|
||||
|
||||
if self._stack:
|
||||
if not (
|
||||
isinstance(self._stack[-1], ChainRun)
|
||||
or isinstance(self._stack[-1], ToolRun)
|
||||
):
|
||||
raise TracerException(
|
||||
f"Parent run with UUID {run.parent_uuid} not found."
|
||||
f"Nested {run.__class__.__name__} can only be"
|
||||
f" logged inside a ChainRun or ToolRun"
|
||||
)
|
||||
self._add_child_run(self._stack[-1], run)
|
||||
self._stack.append(run)
|
||||
|
||||
self.run_map[run.uuid] = run
|
||||
|
||||
def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
def _end_trace(self) -> None:
|
||||
"""End a trace for a run."""
|
||||
if not run.parent_uuid:
|
||||
run = self._stack.pop()
|
||||
if not self._stack:
|
||||
self._execution_order = 1
|
||||
self._persist_run(run)
|
||||
else:
|
||||
parent_run = self.run_map.get(run.parent_uuid)
|
||||
if parent_run is None:
|
||||
raise TracerException(
|
||||
f"Parent run with UUID {run.parent_uuid} not found."
|
||||
)
|
||||
if isinstance(parent_run, LLMRun):
|
||||
raise TracerException("LLM Runs are not allowed to have children. ")
|
||||
if run.child_execution_order > parent_run.child_execution_order:
|
||||
parent_run.child_execution_order = run.child_execution_order
|
||||
self.run_map.pop(run.uuid)
|
||||
|
||||
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
if parent_run_id is None:
|
||||
return 1
|
||||
|
||||
parent_run = self.run_map.get(parent_run_id)
|
||||
if parent_run is None:
|
||||
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
|
||||
|
||||
if isinstance(parent_run, LLMRun):
|
||||
raise TracerException("LLM Runs are not allowed to have children. ")
|
||||
|
||||
return parent_run.child_execution_order + 1
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
if self.session is None:
|
||||
self.session = self.load_default_session()
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
|
||||
if run_id is None:
|
||||
run_id = str(uuid4())
|
||||
|
||||
execution_order = self._get_execution_order(parent_run_id)
|
||||
llm_run = LLMRun(
|
||||
uuid=run_id,
|
||||
parent_uuid=parent_run_id,
|
||||
serialized=serialized,
|
||||
prompts=prompts,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
session_id=self.session.id,
|
||||
execution_order=self._execution_order,
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None:
|
||||
"""End a trace for an LLM run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_end callback.")
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Handle a new token for an LLM run."""
|
||||
pass
|
||||
|
||||
llm_run = self.run_map.get(run_id)
|
||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""End a trace for an LLM run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
llm_run.response = response
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
self._end_trace(llm_run)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].response = response
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
**kwargs: Any,
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Handle an error for an LLM run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_error callback.")
|
||||
|
||||
llm_run = self.run_map.get(run_id)
|
||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
||||
if not self._stack or not isinstance(self._stack[-1], LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
llm_run.error = repr(error)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
self._end_trace(llm_run)
|
||||
self._stack[-1].error = repr(error)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for a chain run."""
|
||||
if self.session is None:
|
||||
self.session = self.load_default_session()
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
|
||||
execution_order = self._get_execution_order(parent_run_id)
|
||||
chain_run = ChainRun(
|
||||
uuid=run_id,
|
||||
parent_uuid=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs=inputs,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
execution_order=self._execution_order,
|
||||
child_runs=[],
|
||||
session_id=self.session.id,
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
)
|
||||
self._start_trace(chain_run)
|
||||
|
||||
def on_chain_end(
|
||||
self, outputs: Dict[str, Any], *, run_id: str, **kwargs: Any
|
||||
) -> None:
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""End a trace for a chain run."""
|
||||
chain_run = self.run_map.get(run_id)
|
||||
if chain_run is None or not isinstance(chain_run, ChainRun):
|
||||
if not self._stack or not isinstance(self._stack[-1], ChainRun):
|
||||
raise TracerException("No ChainRun found to be traced")
|
||||
|
||||
chain_run.outputs = outputs
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
self._end_trace(chain_run)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].outputs = outputs
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
**kwargs: Any,
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Handle an error for a chain run."""
|
||||
chain_run = self.run_map.get(run_id)
|
||||
if chain_run is None or not isinstance(chain_run, ChainRun):
|
||||
if not self._stack or not isinstance(self._stack[-1], ChainRun):
|
||||
raise TracerException("No ChainRun found to be traced")
|
||||
|
||||
chain_run.error = repr(error)
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
self._end_trace(chain_run)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].error = repr(error)
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for a tool run."""
|
||||
if self.session is None:
|
||||
self.session = self.load_default_session()
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
|
||||
execution_order = self._get_execution_order(parent_run_id)
|
||||
tool_run = ToolRun(
|
||||
uuid=run_id,
|
||||
parent_uuid=parent_run_id,
|
||||
serialized=serialized,
|
||||
# TODO: this is duplicate info as above, not needed.
|
||||
action=str(serialized),
|
||||
tool_input=input_str,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
execution_order=self._execution_order,
|
||||
child_runs=[],
|
||||
session_id=self.session.id,
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
)
|
||||
self._start_trace(tool_run)
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""End a trace for a tool run."""
|
||||
tool_run = self.run_map.get(run_id)
|
||||
if tool_run is None or not isinstance(tool_run, ToolRun):
|
||||
if not self._stack or not isinstance(self._stack[-1], ToolRun):
|
||||
raise TracerException("No ToolRun found to be traced")
|
||||
|
||||
tool_run.output = output
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
self._end_trace(tool_run)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].output = output
|
||||
|
||||
self._end_trace()
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
**kwargs: Any,
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Handle an error for a tool run."""
|
||||
tool_run = self.run_map.get(run_id)
|
||||
if tool_run is None or not isinstance(tool_run, ToolRun):
|
||||
if not self._stack or not isinstance(self._stack[-1], ToolRun):
|
||||
raise TracerException("No ToolRun found to be traced")
|
||||
|
||||
tool_run.error = repr(error)
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
self._end_trace(tool_run)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].error = repr(error)
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||
"""Deepcopy the tracer."""
|
||||
return self
|
||||
self._end_trace()
|
||||
|
||||
def __copy__(self) -> BaseTracer:
|
||||
"""Copy the tracer."""
|
||||
return self
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Handle a text message."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Handle an agent finish message."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
|
||||
class Tracer(BaseTracer, ABC):
|
||||
"""A non-thread safe implementation of the BaseTracer interface."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a tracer."""
|
||||
self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
self._tracer_execution_order = 1
|
||||
self._tracer_session: Optional[TracerSession] = None
|
||||
|
||||
@property
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
return self._tracer_stack
|
||||
|
||||
@property
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
return self._tracer_execution_order
|
||||
|
||||
@_execution_order.setter
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
self._tracer_execution_order = value
|
||||
|
||||
@property
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
return self._tracer_session
|
||||
|
||||
@_session.setter
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
if self._stack:
|
||||
raise TracerException(
|
||||
"Cannot set a session while a trace is being recorded"
|
||||
)
|
||||
self._tracer_session = value
|
||||
|
||||
|
||||
@dataclass
|
||||
class TracerStack(threading.local):
|
||||
"""A stack of runs used for logging."""
|
||||
|
||||
stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
|
||||
execution_order: int = 1
|
||||
|
||||
|
||||
class SharedTracer(Singleton, BaseTracer, ABC):
|
||||
"""A thread-safe Singleton implementation of BaseTracer."""
|
||||
|
||||
_tracer_stack = TracerStack()
|
||||
_tracer_session = None
|
||||
|
||||
@property
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
return self._tracer_stack.stack
|
||||
|
||||
@property
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
return self._tracer_stack.execution_order
|
||||
|
||||
@_execution_order.setter
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
self._tracer_stack.execution_order = value
|
||||
|
||||
@property
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
return self._tracer_session
|
||||
|
||||
@_session.setter
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
with self._lock:
|
||||
# TODO: currently, we are only checking current thread's stack.
|
||||
# Need to make sure that we are not in the middle of a trace
|
||||
# in any thread.
|
||||
if self._stack:
|
||||
raise TracerException(
|
||||
"Cannot set a session while a trace is being recorded"
|
||||
)
|
||||
self._tracer_session = value
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import requests
|
||||
@@ -17,17 +18,14 @@ from langchain.callbacks.tracers.schemas import (
|
||||
)
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
class BaseLangChainTracer(BaseTracer, ABC):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
self.session = self.load_session(session_name)
|
||||
always_verbose: bool = True
|
||||
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
@@ -61,29 +59,54 @@ class LangChainTracer(BaseTracer):
|
||||
session = TracerSession(id=1, **session_create.dict())
|
||||
return session
|
||||
|
||||
def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a session from the tracer."""
|
||||
try:
|
||||
url = f"{self._endpoint}/sessions"
|
||||
if session_name:
|
||||
url += f"?name={session_name}"
|
||||
r = requests.get(url, headers=self._headers)
|
||||
|
||||
r = requests.get(
|
||||
f"{self._endpoint}/sessions?name={session_name}",
|
||||
headers=self._headers,
|
||||
)
|
||||
tracer_session = TracerSession(**r.json()[0])
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
except Exception as e:
|
||||
session_type = "default" if not session_name else session_name
|
||||
logging.warning(
|
||||
f"Failed to load {session_type} session, using empty session: {e}"
|
||||
f"Failed to load session {session_name}, using empty session: {e}"
|
||||
)
|
||||
tracer_session = TracerSession(id=1)
|
||||
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a session with the given name from the tracer."""
|
||||
return self._load_session(session_name)
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
return self._load_session("default")
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{self._endpoint}/sessions",
|
||||
headers=self._headers,
|
||||
)
|
||||
# Use the first session result
|
||||
tracer_session = TracerSession(**r.json()[0])
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to default session, using empty session: {e}")
|
||||
tracer_session = TracerSession(id=1)
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
if isinstance(child_run, LLMRun):
|
||||
parent_run.child_llm_runs.append(child_run)
|
||||
elif isinstance(child_run, ChainRun):
|
||||
parent_run.child_chain_runs.append(child_run)
|
||||
else:
|
||||
parent_run.child_tool_runs.append(child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return None
|
||||
|
||||
@@ -32,13 +32,11 @@ class TracerSession(TracerSessionBase):
|
||||
class BaseRun(BaseModel):
|
||||
"""Base class for Run."""
|
||||
|
||||
uuid: str
|
||||
parent_uuid: Optional[str] = None
|
||||
id: Optional[Union[int, str]] = None
|
||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
execution_order: int
|
||||
child_execution_order: int
|
||||
serialized: Dict[str, Any]
|
||||
session_id: int
|
||||
error: Optional[str] = None
|
||||
@@ -59,6 +57,7 @@ class ChainRun(BaseRun):
|
||||
child_llm_runs: List[LLMRun] = Field(default_factory=list)
|
||||
child_chain_runs: List[ChainRun] = Field(default_factory=list)
|
||||
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolRun(BaseRun):
|
||||
@@ -70,6 +69,7 @@ class ToolRun(BaseRun):
|
||||
child_llm_runs: List[LLMRun] = Field(default_factory=list)
|
||||
child_chain_runs: List[ChainRun] = Field(default_factory=list)
|
||||
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
|
||||
|
||||
|
||||
ChainRun.update_forward_refs()
|
||||
|
||||
@@ -5,16 +5,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.requests import TextRequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class APIChain(Chain):
|
||||
@@ -65,21 +61,16 @@ class APIChain(Chain):
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
question = inputs[self.question_key]
|
||||
api_url = self.api_request_chain.predict(
|
||||
question=question,
|
||||
api_docs=self.api_docs,
|
||||
callbacks=_run_manager.get_child(),
|
||||
question=question, api_docs=self.api_docs
|
||||
)
|
||||
self.callback_manager.on_text(
|
||||
api_url, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
_run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
|
||||
api_response = self.requests_wrapper.get(api_url)
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
api_response, color="yellow", end="\n", verbose=self.verbose
|
||||
)
|
||||
answer = self.api_answer_chain.predict(
|
||||
@@ -87,27 +78,19 @@ class APIChain(Chain):
|
||||
api_docs=self.api_docs,
|
||||
api_url=api_url,
|
||||
api_response=api_response,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
question = inputs[self.question_key]
|
||||
api_url = await self.api_request_chain.apredict(
|
||||
question=question,
|
||||
api_docs=self.api_docs,
|
||||
callbacks=_run_manager.get_child(),
|
||||
question=question, api_docs=self.api_docs
|
||||
)
|
||||
await _run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
api_url, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
api_response = await self.requests_wrapper.aget(api_url)
|
||||
await _run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
api_response, color="yellow", end="\n", verbose=self.verbose
|
||||
)
|
||||
answer = await self.api_answer_chain.apredict(
|
||||
@@ -115,7 +98,6 @@ class APIChain(Chain):
|
||||
api_docs=self.api_docs,
|
||||
api_url=api_url,
|
||||
api_response=api_response,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return {self.output_key: answer}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
|
||||
from pydantic import BaseModel, Field
|
||||
from requests import Response
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
|
||||
from langchain.chains.api.openapi.response_chain import APIResponderChain
|
||||
from langchain.chains.base import Chain
|
||||
@@ -107,21 +106,16 @@ class OpenAPIEndpointChain(Chain, BaseModel):
|
||||
else:
|
||||
return {self.output_key: output}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
intermediate_steps = {}
|
||||
instructions = inputs[self.instructions_key]
|
||||
instructions = instructions[: self.max_text_length]
|
||||
_api_arguments = self.api_request_chain.predict_and_parse(
|
||||
instructions=instructions, callbacks=_run_manager.get_child()
|
||||
instructions=instructions
|
||||
)
|
||||
api_arguments = cast(str, _api_arguments)
|
||||
intermediate_steps["request_args"] = api_arguments
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
api_arguments, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
if api_arguments.startswith("ERROR"):
|
||||
@@ -147,17 +141,18 @@ class OpenAPIEndpointChain(Chain, BaseModel):
|
||||
response_text = f"Error with message {str(e)}"
|
||||
response_text = response_text[: self.max_text_length]
|
||||
intermediate_steps["response_text"] = response_text
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
response_text, color="blue", end="\n", verbose=self.verbose
|
||||
)
|
||||
if self.api_response_chain is not None:
|
||||
_answer = self.api_response_chain.predict_and_parse(
|
||||
response=response_text,
|
||||
instructions=instructions,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
answer = cast(str, _answer)
|
||||
_run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
|
||||
self.callback_manager.on_text(
|
||||
answer, color="yellow", end="\n", verbose=self.verbose
|
||||
)
|
||||
return self._get_output(answer, intermediate_steps)
|
||||
else:
|
||||
return self._get_output(response_text, intermediate_steps)
|
||||
@@ -193,7 +188,6 @@ class OpenAPIEndpointChain(Chain, BaseModel):
|
||||
verbose: bool = False,
|
||||
return_intermediate_steps: bool = False,
|
||||
raw_response: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any
|
||||
# TODO: Handle async
|
||||
) -> "OpenAPIEndpointChain":
|
||||
@@ -204,17 +198,12 @@ class OpenAPIEndpointChain(Chain, BaseModel):
|
||||
path_params=operation.path_params,
|
||||
)
|
||||
requests_chain = APIRequesterChain.from_llm_and_typescript(
|
||||
llm,
|
||||
typescript_definition=operation.to_typescript(),
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
llm, typescript_definition=operation.to_typescript(), verbose=verbose
|
||||
)
|
||||
if raw_response:
|
||||
response_chain = None
|
||||
else:
|
||||
response_chain = APIResponderChain.from_llm(
|
||||
llm, verbose=verbose, callbacks=callbacks
|
||||
)
|
||||
response_chain = APIResponderChain.from_llm(llm, verbose=verbose)
|
||||
_requests = requests or Requests()
|
||||
return cls(
|
||||
api_request_chain=requests_chain,
|
||||
@@ -224,6 +213,5 @@ class OpenAPIEndpointChain(Chain, BaseModel):
|
||||
param_mapping=param_mapping,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=return_intermediate_steps,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -37,11 +36,7 @@ class APIRequesterChain(LLMChain):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_typescript(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
typescript_definition: str,
|
||||
verbose: bool = True,
|
||||
**kwargs: Any,
|
||||
cls, llm: BaseLLM, typescript_definition: str, verbose: bool = True
|
||||
) -> LLMChain:
|
||||
"""Get the request parser."""
|
||||
output_parser = APIRequesterOutputParser()
|
||||
@@ -51,4 +46,4 @@ class APIRequesterChain(LLMChain):
|
||||
partial_variables={"schema": typescript_definition},
|
||||
input_variables=["instructions"],
|
||||
)
|
||||
return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
|
||||
return cls(prompt=prompt, llm=llm, verbose=verbose)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -36,7 +35,7 @@ class APIResponderChain(LLMChain):
|
||||
"""Get the response parser."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain:
|
||||
def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
|
||||
"""Get the response parser."""
|
||||
output_parser = APIResponderOutputParser()
|
||||
prompt = PromptTemplate(
|
||||
@@ -44,4 +43,4 @@ class APIResponderChain(LLMChain):
|
||||
output_parser=output_parser,
|
||||
input_variables=["response", "instructions"],
|
||||
)
|
||||
return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
|
||||
return cls(prompt=prompt, llm=llm, verbose=verbose)
|
||||
|
||||
@@ -1,23 +1,15 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.schema import BaseMemory
|
||||
|
||||
|
||||
@@ -29,8 +21,9 @@ class Chain(BaseModel, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
callback_manager: BaseCallbackManager = Field(
|
||||
default_factory=get_callback_manager, exclude=True
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
@@ -44,16 +37,15 @@ class Chain(BaseModel, ABC):
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
@@ -90,26 +82,15 @@ class Chain(BaseModel, ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
raise NotImplementedError("Async call not supported for this chain type.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
@@ -123,31 +104,21 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
outputs = self._call(inputs)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
@@ -161,24 +132,30 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
try:
|
||||
outputs = await self._acall(inputs)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||
|
||||
def prep_outputs(
|
||||
@@ -218,13 +195,11 @@ class Chain(BaseModel, ABC):
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
return [self(inputs) for inputs in input_list]
|
||||
|
||||
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
def run(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
@@ -235,17 +210,17 @@ class Chain(BaseModel, ABC):
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return self(args[0], callbacks=callbacks)[self.output_keys[0]]
|
||||
return self(args[0])[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks)[self.output_keys[0]]
|
||||
return self(kwargs)[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
async def arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
@@ -256,10 +231,10 @@ class Chain(BaseModel, ABC):
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]
|
||||
return (await self.acall(args[0]))[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]
|
||||
return (await self.acall(kwargs))[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
|
||||
@@ -5,10 +5,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@@ -72,33 +68,19 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string asynchronously."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
output, extra_return_dict = self.combine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
)
|
||||
output, extra_return_dict = self.combine_docs(docs, **other_keys)
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
output, extra_return_dict = await self.acombine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
)
|
||||
output, extra_return_dict = await self.acombine_docs(docs, **other_keys)
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
@@ -126,17 +108,10 @@ class AnalyzeDocumentChain(Chain):
|
||||
"""
|
||||
return self.combine_docs_chain.output_keys
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
document = inputs[self.input_key]
|
||||
docs = self.text_splitter.create_documents([document])
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
other_keys[self.combine_docs_chain.input_key] = docs
|
||||
return self.combine_docs_chain(
|
||||
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
|
||||
)
|
||||
return self.combine_docs_chain(other_keys, return_only_outputs=True)
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
@@ -130,11 +129,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
return self.combine_document_chain
|
||||
|
||||
def combine_docs(
|
||||
self,
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map reduce manner.
|
||||
|
||||
@@ -143,15 +138,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""
|
||||
results = self.llm_chain.apply(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
)
|
||||
return self._process_results(
|
||||
results, docs, token_max, callbacks=callbacks, **kwargs
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(results, docs, token_max, **kwargs)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map reduce manner.
|
||||
|
||||
@@ -160,17 +152,15 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""
|
||||
results = await self.llm_chain.aapply(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(results, docs, callbacks=callbacks, **kwargs)
|
||||
return self._process_results(results, docs, **kwargs)
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
results: List[Dict],
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[str, dict]:
|
||||
question_result_key = self.llm_chain.output_key
|
||||
@@ -183,9 +173,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
|
||||
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
|
||||
return self._collapse_chain.run(
|
||||
input_documents=docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
return self._collapse_chain.run(input_documents=docs, **kwargs)
|
||||
|
||||
while num_tokens is not None and num_tokens > token_max:
|
||||
new_result_doc_list = _split_list_of_docs(
|
||||
@@ -203,9 +191,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
extra_return_dict = {"intermediate_steps": _results}
|
||||
else:
|
||||
extra_return_dict = {}
|
||||
output = self.combine_document_chain.run(
|
||||
input_documents=result_docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
output = self.combine_document_chain.run(input_documents=result_docs, **kwargs)
|
||||
return output, extra_return_dict
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
@@ -90,22 +89,19 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map rerank manner.
|
||||
|
||||
Combine by mapping first chain over all documents, then reranking the results.
|
||||
"""
|
||||
results = self.llm_chain.apply_and_parse(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(docs, results)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map rerank manner.
|
||||
|
||||
@@ -113,8 +109,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""
|
||||
results = await self.llm_chain.aapply_and_parse(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
|
||||
)
|
||||
return self._process_results(docs, results)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
@@ -86,31 +85,29 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Combine by mapping first chain over all, then stuffing into final chain."""
|
||||
inputs = self._construct_initial_inputs(docs, **kwargs)
|
||||
res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
|
||||
res = self.initial_llm_chain.predict(**inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_inputs = self._construct_refine_inputs(doc, res)
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
|
||||
res = self.refine_llm_chain.predict(**inputs)
|
||||
refine_steps.append(res)
|
||||
return self._construct_result(refine_steps, res)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine by mapping first chain over all, then stuffing into final chain."""
|
||||
inputs = self._construct_initial_inputs(docs, **kwargs)
|
||||
res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
|
||||
res = await self.initial_llm_chain.apredict(**inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_inputs = self._construct_refine_inputs(doc, res)
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
|
||||
res = await self.refine_llm_chain.apredict(**inputs)
|
||||
refine_steps.append(res)
|
||||
return self._construct_result(refine_steps, res)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
@@ -76,21 +75,19 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
# Call predict on the LLM.
|
||||
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
|
||||
return self.llm_chain.predict(**inputs), {}
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
# Call predict on the LLM.
|
||||
return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
|
||||
return await self.llm_chain.apredict(**inputs), {}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
"""Chain for applying constitutional principles to the outputs of another chain."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain.chains.constitutional_ai.principles import PRINCIPLES
|
||||
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class ConstitutionalChain(Chain):
|
||||
@@ -19,19 +18,14 @@ class ConstitutionalChain(Chain):
|
||||
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.chains import LLMChain, ConstitutionalChain
|
||||
from langchain.chains.constitutional_ai.models \
|
||||
import ConstitutionalPrinciple
|
||||
|
||||
llm = OpenAI()
|
||||
|
||||
qa_prompt = PromptTemplate(
|
||||
template="Q: {question} A:",
|
||||
input_variables=["question"],
|
||||
)
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
qa_chain = LLMChain(llm=OpenAI(), prompt=qa_prompt)
|
||||
|
||||
constitutional_chain = ConstitutionalChain.from_llm(
|
||||
llm=llm,
|
||||
chain=qa_chain,
|
||||
constitutional_principles=[
|
||||
ConstitutionalPrinciple(
|
||||
@@ -87,16 +81,11 @@ class ConstitutionalChain(Chain):
|
||||
"""Defines the output keys."""
|
||||
return ["output"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
response = self.chain.run(**inputs)
|
||||
input_prompt = self.chain.prompt.format(**inputs)
|
||||
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
text="Initial response: " + response + "\n\n",
|
||||
verbose=self.verbose,
|
||||
color="yellow",
|
||||
@@ -109,7 +98,6 @@ class ConstitutionalChain(Chain):
|
||||
input_prompt=input_prompt,
|
||||
output_from_model=response,
|
||||
critique_request=constitutional_principle.critique_request,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
critique = self._parse_critique(
|
||||
output_string=raw_critique,
|
||||
@@ -123,23 +111,22 @@ class ConstitutionalChain(Chain):
|
||||
critique_request=constitutional_principle.critique_request,
|
||||
critique=critique,
|
||||
revision_request=constitutional_principle.revision_request,
|
||||
callbacks=_run_manager.get_child(),
|
||||
).strip()
|
||||
response = revision
|
||||
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
text=f"Applying {constitutional_principle.name}..." + "\n\n",
|
||||
verbose=self.verbose,
|
||||
color="green",
|
||||
)
|
||||
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
text="Critique: " + critique + "\n\n",
|
||||
verbose=self.verbose,
|
||||
color="blue",
|
||||
)
|
||||
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
text="Updated response: " + revision + "\n\n",
|
||||
verbose=self.verbose,
|
||||
color="yellow",
|
||||
|
||||
@@ -8,11 +8,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
@@ -20,7 +15,7 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseMessage, BaseRetriever, Document
|
||||
from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
# Depending on the memory type and configuration, the chat history format may differ.
|
||||
@@ -86,20 +81,14 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
question = inputs["question"]
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
chat_history_str = get_chat_history(inputs["chat_history"])
|
||||
|
||||
if chat_history_str:
|
||||
callbacks = _run_manager.get_child()
|
||||
new_question = self.question_generator.run(
|
||||
question=question, chat_history=chat_history_str, callbacks=callbacks
|
||||
question=question, chat_history=chat_history_str
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
@@ -107,9 +96,7 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
answer = self.combine_docs_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
||||
)
|
||||
answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs)
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
@@ -119,19 +106,13 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
question = inputs["question"]
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
chat_history_str = get_chat_history(inputs["chat_history"])
|
||||
if chat_history_str:
|
||||
callbacks = _run_manager.get_child()
|
||||
new_question = await self.question_generator.arun(
|
||||
question=question, chat_history=chat_history_str, callbacks=callbacks
|
||||
question=question, chat_history=chat_history_str
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
@@ -139,9 +120,7 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
answer = await self.combine_docs_chain.arun(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
||||
)
|
||||
answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs)
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Question answering over a graph."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -52,25 +51,18 @@ class GraphQAChain(Chain):
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
entity_chain = LLMChain(llm=llm, prompt=entity_prompt)
|
||||
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
entity_extraction_chain=entity_chain,
|
||||
**kwargs,
|
||||
)
|
||||
return cls(qa_chain=qa_chain, entity_extraction_chain=entity_chain, **kwargs)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Extract entities, look up info and answer question."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
entity_string = self.entity_extraction_chain.run(question)
|
||||
|
||||
_run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
self.callback_manager.on_text(
|
||||
"Entities Extracted:", end="\n", verbose=self.verbose
|
||||
)
|
||||
self.callback_manager.on_text(
|
||||
entity_string, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
entities = get_entities(entity_string)
|
||||
@@ -78,10 +70,9 @@ class GraphQAChain(Chain):
|
||||
for entity in entities:
|
||||
triplets = self.graph.get_entity_knowledge(entity)
|
||||
context += "\n".join(triplets)
|
||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
|
||||
result = self.qa_chain(
|
||||
{"question": question, "context": context},
|
||||
callbacks=_run_manager.get_child(),
|
||||
self.callback_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
self.callback_manager.on_text(
|
||||
context, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
result = self.qa_chain({"question": question, "context": context})
|
||||
return {self.output_key: result[self.qa_chain.output_key]}
|
||||
|
||||
@@ -4,12 +4,11 @@ https://arxiv.org/abs/2212.10496
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -58,27 +57,18 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
embeddings = self.embed_documents(documents)
|
||||
return self.combine_embeddings(embeddings)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Call the internal llm chain."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
return self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
||||
return self.llm_chain._call(inputs)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
base_embeddings: Embeddings,
|
||||
prompt_key: str,
|
||||
**kwargs: Any,
|
||||
cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str
|
||||
) -> HypotheticalDocumentEmbedder:
|
||||
"""Load and use LLMChain for a specific prompt key."""
|
||||
prompt = PROMPT_MAP[prompt_key]
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
|
||||
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
|
||||
@@ -5,19 +5,11 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
@@ -61,40 +53,21 @@ class LLMChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = self.generate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.apply([inputs])[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
return self.llm.generate_prompt(
|
||||
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
prompts, stop = self.prep_prompts(input_list)
|
||||
return self.llm.generate_prompt(prompts, stop)
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = await self.aprep_prompts(input_list)
|
||||
return await self.llm.agenerate_prompt(
|
||||
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
return await self.llm.agenerate_prompt(prompts, stop)
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
@@ -106,8 +79,7 @@ class LLMChain(Chain):
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
@@ -116,9 +88,7 @@ class LLMChain(Chain):
|
||||
return prompts, stop
|
||||
|
||||
async def aprep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
@@ -130,8 +100,12 @@ class LLMChain(Chain):
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_text(
|
||||
_text, end="\n", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
@@ -139,45 +113,15 @@ class LLMChain(Chain):
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = self.generate(input_list, run_manager=run_manager)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
response = self.generate(input_list)
|
||||
return self.create_outputs(response)
|
||||
|
||||
async def aapply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = await self.agenerate(input_list, run_manager=run_manager)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
await run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
response = await self.agenerate(input_list)
|
||||
return self.create_outputs(response)
|
||||
|
||||
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
||||
"""Create outputs from response."""
|
||||
@@ -187,19 +131,13 @@ class LLMChain(Chain):
|
||||
for generation in response.generations
|
||||
]
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = await self.agenerate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return (await self.aapply([inputs]))[0]
|
||||
|
||||
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
def predict(self, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
@@ -210,13 +148,12 @@ class LLMChain(Chain):
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs, callbacks=callbacks)[self.output_key]
|
||||
return self(kwargs)[self.output_key]
|
||||
|
||||
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
async def apredict(self, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
@@ -227,33 +164,31 @@ class LLMChain(Chain):
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
|
||||
return (await self.acall(kwargs))[self.output_key]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, str]]:
|
||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(callbacks=callbacks, **kwargs)
|
||||
result = self.predict(**kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
async def apredict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
self, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call apredict and then parse the results."""
|
||||
result = await self.apredict(callbacks=callbacks, **kwargs)
|
||||
result = await self.apredict(**kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
def apply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = self.apply(input_list, callbacks=callbacks)
|
||||
result = self.apply(input_list)
|
||||
return self._parse_result(result)
|
||||
|
||||
def _parse_result(
|
||||
@@ -267,10 +202,10 @@ class LLMChain(Chain):
|
||||
return result
|
||||
|
||||
async def aapply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = await self.aapply(input_list, callbacks=callbacks)
|
||||
result = await self.aapply(input_list)
|
||||
return self._parse_result(result)
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,23 +1,15 @@
|
||||
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import Extra
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import OutputParserException
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMBashChain(Chain):
|
||||
"""Chain that interprets a prompt and executes bash code to perform bash operations.
|
||||
@@ -26,17 +18,14 @@ class LLMBashChain(Chain):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import LLMBashChain, OpenAI
|
||||
llm_bash = LLMBashChain.from_llm(OpenAI())
|
||||
llm_bash = LLMBashChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
llm: BaseLanguageModel
|
||||
"""LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated]"""
|
||||
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -44,26 +33,6 @@ class LLMBashChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMBashChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain or using the from_llm class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@root_validator
|
||||
def validate_prompt(cls, values: Dict) -> Dict:
|
||||
if values["llm_chain"].prompt.output_parser is None:
|
||||
raise ValueError(
|
||||
"The prompt used by llm_chain is expected to have an output_parser."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
@@ -80,45 +49,31 @@ class LLMBashChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
|
||||
bash_executor = BashProcess()
|
||||
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||
|
||||
t = llm_executor.predict(question=inputs[self.input_key])
|
||||
self.callback_manager.on_text(t, color="green", verbose=self.verbose)
|
||||
|
||||
t = self.llm_chain.predict(
|
||||
question=inputs[self.input_key], callbacks=_run_manager.get_child()
|
||||
)
|
||||
_run_manager.on_text(t, color="green", verbose=self.verbose)
|
||||
t = t.strip()
|
||||
try:
|
||||
command_list = self.llm_chain.prompt.output_parser.parse(t) # type: ignore[union-attr]
|
||||
except OutputParserException as e:
|
||||
_run_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
if t.startswith("```bash"):
|
||||
# Split the string into a list of substrings
|
||||
command_list = t.split("\n")
|
||||
print(command_list)
|
||||
|
||||
if self.verbose:
|
||||
_run_manager.on_text("\nCode: ", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(command_list), color="yellow", verbose=self.verbose
|
||||
)
|
||||
output = self.bash_process.run(command_list)
|
||||
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
# Remove the first and last substrings
|
||||
command_list = [s for s in command_list[1:-1]]
|
||||
output = bash_executor.run(command_list)
|
||||
|
||||
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {t}")
|
||||
return {self.output_key: output}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_bash_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMBashChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
# flake8: noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
|
||||
|
||||
@@ -25,36 +19,4 @@ That is the format. Begin!
|
||||
|
||||
Question: {question}"""
|
||||
|
||||
|
||||
class BashOutputParser(BaseOutputParser):
|
||||
"""Parser for bash output."""
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
if "```bash" in text:
|
||||
return self.get_code_blocks(text)
|
||||
else:
|
||||
raise OutputParserException(
|
||||
f"Failed to parse bash output. Got: {text}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_code_blocks(t: str) -> List[str]:
|
||||
"""Get multiple code blocks from the LLM result."""
|
||||
code_blocks: List[str] = []
|
||||
# Bash markdown code blocks
|
||||
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
|
||||
for match in pattern.finditer(t):
|
||||
matched = match.group(1).strip()
|
||||
if matched:
|
||||
code_blocks.extend(
|
||||
[line for line in matched.split("\n") if line.strip()]
|
||||
)
|
||||
|
||||
return code_blocks
|
||||
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
output_parser=BashOutputParser(),
|
||||
)
|
||||
PROMPT = PromptTemplate(input_variables=["question"], template=_PROMPT_TEMPLATE)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""Chain for question-answering with self-verification."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_checker.prompt import (
|
||||
@@ -20,48 +18,6 @@ from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
|
||||
def _load_question_to_checked_assertions_chain(
|
||||
llm: BaseLLM,
|
||||
create_draft_answer_prompt: PromptTemplate,
|
||||
list_assertions_prompt: PromptTemplate,
|
||||
check_assertions_prompt: PromptTemplate,
|
||||
revised_answer_prompt: PromptTemplate,
|
||||
) -> SequentialChain:
|
||||
create_draft_answer_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=create_draft_answer_prompt,
|
||||
output_key="statement",
|
||||
)
|
||||
list_assertions_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=list_assertions_prompt,
|
||||
output_key="assertions",
|
||||
)
|
||||
check_assertions_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=check_assertions_prompt,
|
||||
output_key="checked_assertions",
|
||||
)
|
||||
revised_answer_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=revised_answer_prompt,
|
||||
output_key="revised_statement",
|
||||
)
|
||||
chains = [
|
||||
create_draft_answer_chain,
|
||||
list_assertions_chain,
|
||||
check_assertions_chain,
|
||||
revised_answer_chain,
|
||||
]
|
||||
question_to_checked_assertions_chain = SequentialChain(
|
||||
chains=chains,
|
||||
input_variables=["question"],
|
||||
output_variables=["revised_statement"],
|
||||
verbose=True,
|
||||
)
|
||||
return question_to_checked_assertions_chain
|
||||
|
||||
|
||||
class LLMCheckerChain(Chain):
|
||||
"""Chain for question-answering with self-verification.
|
||||
|
||||
@@ -70,21 +26,16 @@ class LLMCheckerChain(Chain):
|
||||
|
||||
from langchain import OpenAI, LLMCheckerChain
|
||||
llm = OpenAI(temperature=0.7)
|
||||
checker_chain = LLMCheckerChain.from_llm(llm)
|
||||
checker_chain = LLMCheckerChain(llm=llm)
|
||||
"""
|
||||
|
||||
question_to_checked_assertions_chain: SequentialChain
|
||||
|
||||
llm: Optional[BaseLLM] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
|
||||
"""[Deprecated]"""
|
||||
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
|
||||
"""[Deprecated]"""
|
||||
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
|
||||
"""[Deprecated]"""
|
||||
revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
|
||||
"""[Deprecated] Prompt to use when questioning the documents."""
|
||||
"""Prompt to use when questioning the documents."""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
|
||||
@@ -94,34 +45,6 @@ class LLMCheckerChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMCheckerChain with an llm is deprecated. "
|
||||
"Please instantiate with question_to_checked_assertions_chain "
|
||||
"or using the from_llm class method."
|
||||
)
|
||||
if (
|
||||
"question_to_checked_assertions_chain" not in values
|
||||
and values["llm"] is not None
|
||||
):
|
||||
question_to_checked_assertions_chain = (
|
||||
_load_question_to_checked_assertions_chain(
|
||||
values["llm"],
|
||||
values.get(
|
||||
"create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT
|
||||
),
|
||||
values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT),
|
||||
values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT),
|
||||
values.get("revised_answer_prompt", REVISED_ANSWER_PROMPT),
|
||||
)
|
||||
)
|
||||
values[
|
||||
"question_to_checked_assertions_chain"
|
||||
] = question_to_checked_assertions_chain
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the singular input key.
|
||||
@@ -138,43 +61,43 @@ class LLMCheckerChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
question = inputs[self.input_key]
|
||||
|
||||
output = self.question_to_checked_assertions_chain(
|
||||
{"question": question}, callbacks=_run_manager.get_child()
|
||||
create_draft_answer_chain = LLMChain(
|
||||
llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement"
|
||||
)
|
||||
list_assertions_chain = LLMChain(
|
||||
llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions"
|
||||
)
|
||||
check_assertions_chain = LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=self.check_assertions_prompt,
|
||||
output_key="checked_assertions",
|
||||
)
|
||||
|
||||
revised_answer_chain = LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=self.revised_answer_prompt,
|
||||
output_key="revised_statement",
|
||||
)
|
||||
|
||||
chains = [
|
||||
create_draft_answer_chain,
|
||||
list_assertions_chain,
|
||||
check_assertions_chain,
|
||||
revised_answer_chain,
|
||||
]
|
||||
|
||||
question_to_checked_assertions_chain = SequentialChain(
|
||||
chains=chains,
|
||||
input_variables=["question"],
|
||||
output_variables=["revised_statement"],
|
||||
verbose=True,
|
||||
)
|
||||
output = question_to_checked_assertions_chain({"question": question})
|
||||
return {self.output_key: output["revised_statement"]}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_checker_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT,
|
||||
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT,
|
||||
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
|
||||
revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMCheckerChain:
|
||||
question_to_checked_assertions_chain = (
|
||||
_load_question_to_checked_assertions_chain(
|
||||
llm,
|
||||
create_draft_answer_prompt,
|
||||
list_assertions_prompt,
|
||||
check_assertions_prompt,
|
||||
revised_answer_prompt,
|
||||
)
|
||||
)
|
||||
return cls(
|
||||
question_to_checked_assertions_chain=question_to_checked_assertions_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
"""Chain that interprets a prompt and executes python code to do math."""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
import numexpr
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LLMMathChain(Chain):
|
||||
@@ -27,14 +20,13 @@ class LLMMathChain(Chain):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import LLMMathChain, OpenAI
|
||||
llm_math = LLMMathChain.from_llm(OpenAI())
|
||||
llm_math = LLMMathChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
llm: BaseLanguageModel
|
||||
"""LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
"""Prompt to use to translate to python if neccessary."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
@@ -44,19 +36,6 @@ class LLMMathChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMMathChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_llm "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
@@ -89,17 +68,15 @@ class LLMMathChain(Chain):
|
||||
# Remove any leading and trailing brackets from the output
|
||||
return re.sub(r"^\[|\]$", "", output)
|
||||
|
||||
def _process_llm_result(
|
||||
self, llm_output: str, run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
def _process_llm_result(self, llm_output: str) -> Dict[str, str]:
|
||||
self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
@@ -109,19 +86,30 @@ class LLMMathChain(Chain):
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_text(
|
||||
llm_output, color="green", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_text(
|
||||
llm_output, color="green", verbose=self.verbose
|
||||
)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await self.callback_manager.on_text(
|
||||
output, color="yellow", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
self.callback_manager.on_text(
|
||||
output, color="yellow", verbose=self.verbose
|
||||
)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
@@ -131,44 +119,31 @@ class LLMMathChain(Chain):
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
llm_output = self.llm_chain.predict(
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(
|
||||
prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
|
||||
)
|
||||
return self._process_llm_result(llm_output, _run_manager)
|
||||
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||
llm_output = llm_executor.predict(
|
||||
question=inputs[self.input_key], stop=["```output"]
|
||||
)
|
||||
return self._process_llm_result(llm_output)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(
|
||||
prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, _run_manager)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_text(
|
||||
inputs[self.input_key], verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||
llm_output = await llm_executor.apredict(
|
||||
question=inputs[self.input_key], stop=["```output"]
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_math_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMMathChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Chain that hits a URL and then uses an LLM to parse results."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.requests import TextRequestsWrapper
|
||||
@@ -62,14 +61,9 @@ class LLMRequestsChain(Chain):
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
url = inputs[self.input_key]
|
||||
@@ -77,9 +71,7 @@ class LLMRequestsChain(Chain):
|
||||
# extract the text from the html
|
||||
soup = BeautifulSoup(res, "html.parser")
|
||||
other_keys[self.requests_key] = soup.get_text()[: self.text_length]
|
||||
result = self.llm_chain.predict(
|
||||
callbacks=_run_manager.get_child(), **other_keys
|
||||
)
|
||||
result = self.llm_chain.predict(**other_keys)
|
||||
return {self.output_key: result}
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""Chain for summarization with self-verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
@@ -31,48 +27,6 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
|
||||
)
|
||||
|
||||
|
||||
def _load_sequential_chain(
|
||||
llm: BaseLLM,
|
||||
create_assertions_prompt: PromptTemplate,
|
||||
check_assertions_prompt: PromptTemplate,
|
||||
revised_summary_prompt: PromptTemplate,
|
||||
are_all_true_prompt: PromptTemplate,
|
||||
verbose: bool = False,
|
||||
) -> SequentialChain:
|
||||
chain = SequentialChain(
|
||||
chains=[
|
||||
LLMChain(
|
||||
llm=llm,
|
||||
prompt=create_assertions_prompt,
|
||||
output_key="assertions",
|
||||
verbose=verbose,
|
||||
),
|
||||
LLMChain(
|
||||
llm=llm,
|
||||
prompt=check_assertions_prompt,
|
||||
output_key="checked_assertions",
|
||||
verbose=verbose,
|
||||
),
|
||||
LLMChain(
|
||||
llm=llm,
|
||||
prompt=revised_summary_prompt,
|
||||
output_key="revised_summary",
|
||||
verbose=verbose,
|
||||
),
|
||||
LLMChain(
|
||||
llm=llm,
|
||||
output_key="all_true",
|
||||
prompt=are_all_true_prompt,
|
||||
verbose=verbose,
|
||||
),
|
||||
],
|
||||
input_variables=["summary"],
|
||||
output_variables=["all_true", "revised_summary"],
|
||||
verbose=verbose,
|
||||
)
|
||||
return chain
|
||||
|
||||
|
||||
class LLMSummarizationCheckerChain(Chain):
|
||||
"""Chain for question-answering with self-verification.
|
||||
|
||||
@@ -81,21 +35,16 @@ class LLMSummarizationCheckerChain(Chain):
|
||||
|
||||
from langchain import OpenAI, LLMSummarizationCheckerChain
|
||||
llm = OpenAI(temperature=0.0)
|
||||
checker_chain = LLMSummarizationCheckerChain.from_llm(llm)
|
||||
checker_chain = LLMSummarizationCheckerChain(llm=llm)
|
||||
"""
|
||||
|
||||
sequential_chain: SequentialChain
|
||||
llm: Optional[BaseLLM] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
|
||||
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
|
||||
"""[Deprecated]"""
|
||||
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
|
||||
"""[Deprecated]"""
|
||||
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
|
||||
"""[Deprecated]"""
|
||||
are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT
|
||||
"""[Deprecated]"""
|
||||
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
@@ -108,25 +57,6 @@ class LLMSummarizationCheckerChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMSummarizationCheckerChain with an llm is "
|
||||
"deprecated. Please instantiate with"
|
||||
" sequential_chain argument or using the from_llm class method."
|
||||
)
|
||||
if "sequential_chain" not in values and values["llm"] is not None:
|
||||
values["sequential_chain"] = _load_sequential_chain(
|
||||
values["llm"],
|
||||
values.get("create_assertions_prompt", CREATE_ASSERTIONS_PROMPT),
|
||||
values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT),
|
||||
values.get("revised_summary_prompt", REVISED_SUMMARY_PROMPT),
|
||||
values.get("are_all_true_prompt", ARE_ALL_TRUE_PROMPT),
|
||||
verbose=values.get("verbose", False),
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the singular input key.
|
||||
@@ -143,21 +73,46 @@ class LLMSummarizationCheckerChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
all_true = False
|
||||
count = 0
|
||||
output = None
|
||||
original_input = inputs[self.input_key]
|
||||
chain_input = original_input
|
||||
|
||||
while not all_true and count < self.max_checks:
|
||||
output = self.sequential_chain(
|
||||
{"summary": chain_input}, callbacks=_run_manager.get_child()
|
||||
chain = SequentialChain(
|
||||
chains=[
|
||||
LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=self.create_assertions_prompt,
|
||||
output_key="assertions",
|
||||
verbose=self.verbose,
|
||||
),
|
||||
LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=self.check_assertions_prompt,
|
||||
output_key="checked_assertions",
|
||||
verbose=self.verbose,
|
||||
),
|
||||
LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=self.revised_summary_prompt,
|
||||
output_key="revised_summary",
|
||||
verbose=self.verbose,
|
||||
),
|
||||
LLMChain(
|
||||
llm=self.llm,
|
||||
output_key="all_true",
|
||||
prompt=self.are_all_true_prompt,
|
||||
verbose=self.verbose,
|
||||
),
|
||||
],
|
||||
input_variables=["summary"],
|
||||
output_variables=["all_true", "revised_summary"],
|
||||
verbose=self.verbose,
|
||||
)
|
||||
output = chain({"summary": chain_input})
|
||||
count += 1
|
||||
|
||||
if output["all_true"].strip() == "True":
|
||||
@@ -176,24 +131,3 @@ class LLMSummarizationCheckerChain(Chain):
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_summarization_checker_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT,
|
||||
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
|
||||
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT,
|
||||
are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> LLMSummarizationCheckerChain:
|
||||
chain = _load_sequential_chain(
|
||||
llm,
|
||||
create_assertions_prompt,
|
||||
check_assertions_prompt,
|
||||
revised_summary_prompt,
|
||||
are_all_true_prompt,
|
||||
verbose=verbose,
|
||||
)
|
||||
return cls(sequential_chain=chain, verbose=verbose, **kwargs)
|
||||
|
||||
@@ -5,11 +5,10 @@ then combines the results with another one.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@@ -33,26 +32,16 @@ class MapReduceChain(Chain):
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate,
|
||||
text_splitter: TextSplitter,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
|
||||
) -> MapReduceChain:
|
||||
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
|
||||
reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks)
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
reduce_chain = StuffDocumentsChain(llm_chain=llm_chain)
|
||||
combine_documents_chain = MapReduceDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
combine_document_chain=reduce_chain,
|
||||
callbacks=callbacks,
|
||||
llm_chain=llm_chain, combine_document_chain=reduce_chain
|
||||
)
|
||||
return cls(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
text_splitter=text_splitter,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
combine_documents_chain=combine_documents_chain, text_splitter=text_splitter
|
||||
)
|
||||
|
||||
class Config:
|
||||
@@ -77,16 +66,9 @@ class MapReduceChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
# Split the larger text into smaller chunks.
|
||||
texts = self.text_splitter.split_text(inputs[self.input_key])
|
||||
docs = [Document(page_content=text) for text in texts]
|
||||
outputs = self.combine_documents_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child()
|
||||
)
|
||||
outputs = self.combine_documents_chain.run(input_documents=docs)
|
||||
return {self.output_key: outputs}
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -85,11 +84,7 @@ class OpenAIModerationChain(Chain):
|
||||
return error_str
|
||||
return text
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
text = inputs[self.input_key]
|
||||
results = self.client.create(text)
|
||||
output = self._moderate(text, results["results"][0])
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""Implement an LLM driven browser."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.natbot.prompt import PROMPT
|
||||
@@ -20,15 +18,14 @@ class NatBotChain(Chain):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import NatBotChain
|
||||
natbot = NatBotChain.from_default("Buy me a new hat.")
|
||||
from langchain import NatBotChain, OpenAI
|
||||
natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
objective: str
|
||||
"""Objective that NatBot is tasked with completing."""
|
||||
llm: Optional[BaseLLM] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
input_url_key: str = "url" #: :meta private:
|
||||
input_browser_content_key: str = "browser_content" #: :meta private:
|
||||
previous_command: str = "" #: :meta private:
|
||||
@@ -40,24 +37,11 @@ class NatBotChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an NatBotChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_default "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_default(cls, objective: str) -> NatBotChain:
|
||||
"""Load with default LLMChain."""
|
||||
"""Load with default LLM."""
|
||||
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
|
||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||
return cls(llm_chain=llm_chain, objective=objective)
|
||||
return cls(llm=llm, objective=objective)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -75,20 +59,15 @@ class NatBotChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
url = inputs[self.input_url_key]
|
||||
browser_content = inputs[self.input_browser_content_key]
|
||||
llm_cmd = self.llm_chain.predict(
|
||||
llm_cmd = llm_executor.predict(
|
||||
objective=self.objective,
|
||||
url=url[:100],
|
||||
previous_command=self.previous_command,
|
||||
browser_content=browser_content[:4500],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
llm_cmd = llm_cmd.strip()
|
||||
self.previous_command = llm_cmd
|
||||
|
||||
@@ -4,29 +4,24 @@ As in https://arxiv.org/pdf/2211.10435.pdf.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
|
||||
class PALChain(Chain):
|
||||
"""Implements Program-Aided Language Models."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated]"""
|
||||
prompt: BasePromptTemplate = MATH_PROMPT
|
||||
"""[Deprecated]"""
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate
|
||||
stop: str = "\n\n"
|
||||
get_answer_expr: str = "print(solution())"
|
||||
python_globals: Optional[Dict[str, Any]] = None
|
||||
@@ -40,19 +35,6 @@ class PALChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an PALChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the one of "
|
||||
"the class method constructors from_math_prompt, "
|
||||
"from_colored_object_prompt."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=MATH_PROMPT)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the singular input key.
|
||||
@@ -72,16 +54,12 @@ class PALChain(Chain):
|
||||
else:
|
||||
return [self.output_key, "intermediate_steps"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
code = self.llm_chain.predict(
|
||||
stop=[self.stop], callbacks=_run_manager.get_child(), **inputs
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
code = llm_chain.predict(stop=[self.stop], **inputs)
|
||||
self.callback_manager.on_text(
|
||||
code, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
_run_manager.on_text(code, color="green", end="\n", verbose=self.verbose)
|
||||
repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals)
|
||||
res = repl.run(code + f"\n{self.get_answer_expr}")
|
||||
output = {self.output_key: res.strip()}
|
||||
@@ -92,9 +70,9 @@ class PALChain(Chain):
|
||||
@classmethod
|
||||
def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
|
||||
"""Load PAL from math prompt."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT)
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
llm=llm,
|
||||
prompt=MATH_PROMPT,
|
||||
stop="\n\n",
|
||||
get_answer_expr="print(solution())",
|
||||
**kwargs,
|
||||
@@ -105,9 +83,9 @@ class PALChain(Chain):
|
||||
cls, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> PALChain:
|
||||
"""Load PAL from colored object prompt."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT)
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
llm=llm,
|
||||
prompt=COLORED_OBJECT_PROMPT,
|
||||
stop="\n\n\n",
|
||||
get_answer_expr="print(answer)",
|
||||
**kwargs,
|
||||
|
||||
@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class BasePromptSelector(BaseModel, ABC):
|
||||
|
||||
@@ -5,12 +5,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
@@ -46,14 +45,11 @@ class QAGenerationChain(Chain):
|
||||
def output_keys(self) -> List[str]:
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, List]:
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
docs = self.text_splitter.create_documents([inputs[self.input_key]])
|
||||
results = self.llm_chain.generate(
|
||||
[{"text": d.page_content} for d in docs], run_manager=run_manager
|
||||
)
|
||||
results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
|
||||
qa = [json.loads(res[0].text) for res in results.generations]
|
||||
return {self.output_key: qa}
|
||||
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -8,11 +8,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@@ -26,6 +21,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class BaseQAWithSourcesChain(Chain, ABC):
|
||||
@@ -118,16 +114,9 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
docs = self._get_docs(inputs)
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
answer = self.combine_documents_chain.run(input_documents=docs, **inputs)
|
||||
if re.search(r"SOURCES:\s", answer):
|
||||
answer, sources = re.split(r"SOURCES:\s", answer)
|
||||
else:
|
||||
@@ -144,16 +133,9 @@ class BaseQAWithSourcesChain(Chain, ABC):
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
docs = await self._aget_docs(inputs)
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
||||
)
|
||||
answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs)
|
||||
if re.search(r"SOURCES:\s", answer):
|
||||
answer, sources = re.split(r"SOURCES:\s", answer)
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Load question answering with sources chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
@@ -15,6 +14,7 @@ from langchain.chains.qa_with_sources import (
|
||||
)
|
||||
from langchain.chains.question_answering import map_rerank_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
"""LLM Chain for turning a user text query into a structured query."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
|
||||
from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.chains.query_constructor.parser import get_parser
|
||||
from langchain.chains.query_constructor.prompt import (
|
||||
DEFAULT_EXAMPLES,
|
||||
DEFAULT_PREFIX,
|
||||
DEFAULT_SCHEMA,
|
||||
DEFAULT_SUFFIX,
|
||||
EXAMPLE_PROMPT,
|
||||
)
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.output_parsers.structured import parse_json_markdown
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
|
||||
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
ast_parse: Callable
|
||||
"""Callable that parses dict into internal representation of query language."""
|
||||
|
||||
def parse(self, text: str) -> StructuredQuery:
|
||||
try:
|
||||
expected_keys = ["query", "filter"]
|
||||
parsed = parse_json_markdown(text, expected_keys)
|
||||
if len(parsed["query"]) == 0:
|
||||
parsed["query"] = " "
|
||||
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
|
||||
parsed["filter"] = None
|
||||
else:
|
||||
parsed["filter"] = self.ast_parse(parsed["filter"])
|
||||
return StructuredQuery(query=parsed["query"], filter=parsed["filter"])
|
||||
except Exception as e:
|
||||
raise OutputParserException(
|
||||
f"Parsing text\n{text}\n raised following error:\n{e}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_components(
|
||||
cls,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
) -> StructuredQueryOutputParser:
|
||||
ast_parser = get_parser(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
return cls(ast_parse=ast_parser.parse)
|
||||
|
||||
|
||||
def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
|
||||
info_dicts = {}
|
||||
for i in info:
|
||||
i_dict = dict(i)
|
||||
info_dicts[i_dict.pop("name")] = i_dict
|
||||
return json.dumps(info_dicts, indent=2).replace("{", "{{").replace("}", "}}")
|
||||
|
||||
|
||||
def _get_prompt(
|
||||
document_contents: str,
|
||||
attribute_info: Sequence[AttributeInfo],
|
||||
examples: Optional[List] = None,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
attribute_str = _format_attribute_info(attribute_info)
|
||||
examples = examples or DEFAULT_EXAMPLES
|
||||
allowed_comparators = allowed_comparators or list(Comparator)
|
||||
allowed_operators = allowed_operators or list(Operator)
|
||||
schema = DEFAULT_SCHEMA.format(
|
||||
allowed_comparators=" | ".join(allowed_comparators),
|
||||
allowed_operators=" | ".join(allowed_operators),
|
||||
)
|
||||
prefix = DEFAULT_PREFIX.format(schema=schema)
|
||||
suffix = DEFAULT_SUFFIX.format(
|
||||
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
||||
)
|
||||
output_parser = StructuredQueryOutputParser.from_components(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
return FewShotPromptTemplate(
|
||||
examples=DEFAULT_EXAMPLES,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
input_variables=["query"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
output_parser=output_parser,
|
||||
)
|
||||
|
||||
|
||||
def load_query_constructor_chain(
|
||||
llm: BaseLanguageModel,
|
||||
document_contents: str,
|
||||
attribute_info: List[AttributeInfo],
|
||||
examples: Optional[List] = None,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMChain:
|
||||
prompt = _get_prompt(
|
||||
document_contents,
|
||||
attribute_info,
|
||||
examples=examples,
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
)
|
||||
return LLMChain(llm=llm, prompt=prompt, **kwargs)
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Internal representation of a structured query language."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Visitor(ABC):
|
||||
"""Defines interface for IR translation using visitor pattern."""
|
||||
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None
|
||||
allowed_operators: Optional[Sequence[Operator]] = None
|
||||
|
||||
@abstractmethod
|
||||
def visit_operation(self, operation: Operation) -> Any:
|
||||
"""Translate an Operation."""
|
||||
|
||||
@abstractmethod
|
||||
def visit_comparison(self, comparison: Comparison) -> Any:
|
||||
"""Translate a Comparison."""
|
||||
|
||||
@abstractmethod
|
||||
def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
|
||||
"""Translate a StructuredQuery."""
|
||||
|
||||
|
||||
def _to_snake_case(name: str) -> str:
|
||||
"""Convert a name into snake_case."""
|
||||
snake_case = ""
|
||||
for i, char in enumerate(name):
|
||||
if char.isupper() and i != 0:
|
||||
snake_case += "_" + char.lower()
|
||||
else:
|
||||
snake_case += char.lower()
|
||||
return snake_case
|
||||
|
||||
|
||||
class Expr(BaseModel):
|
||||
def accept(self, visitor: Visitor) -> Any:
|
||||
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
|
||||
self
|
||||
)
|
||||
|
||||
|
||||
class Operator(str, Enum):
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
NOT = "not"
|
||||
|
||||
|
||||
class Comparator(str, Enum):
|
||||
EQ = "eq"
|
||||
GT = "gt"
|
||||
GTE = "gte"
|
||||
LT = "lt"
|
||||
LTE = "lte"
|
||||
|
||||
|
||||
class FilterDirective(Expr, ABC):
|
||||
"""A filtering expression."""
|
||||
|
||||
|
||||
class Comparison(FilterDirective):
|
||||
"""A comparison to a value."""
|
||||
|
||||
comparator: Comparator
|
||||
attribute: str
|
||||
value: Any
|
||||
|
||||
|
||||
class Operation(FilterDirective):
|
||||
"""A logical operation over other directives."""
|
||||
|
||||
operator: Operator
|
||||
arguments: List[FilterDirective]
|
||||
|
||||
|
||||
class StructuredQuery(Expr):
|
||||
query: str
|
||||
filter: Optional[FilterDirective]
|
||||
@@ -1,120 +0,0 @@
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
try:
|
||||
from lark import Lark, Transformer, v_args
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
FilterDirective,
|
||||
Operation,
|
||||
Operator,
|
||||
)
|
||||
|
||||
GRAMMAR = """
|
||||
?program: func_call
|
||||
?expr: func_call
|
||||
| value
|
||||
|
||||
func_call: CNAME "(" [args] ")"
|
||||
|
||||
?value: SIGNED_INT -> int
|
||||
| SIGNED_FLOAT -> float
|
||||
| list
|
||||
| string
|
||||
| ("false" | "False" | "FALSE") -> false
|
||||
| ("true" | "True" | "TRUE") -> true
|
||||
|
||||
args: expr ("," expr)*
|
||||
string: /'[^']*'/ | ESCAPED_STRING
|
||||
list: "[" [args] "]"
|
||||
|
||||
%import common.CNAME
|
||||
%import common.ESCAPED_STRING
|
||||
%import common.SIGNED_FLOAT
|
||||
%import common.SIGNED_INT
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
|
||||
@v_args(inline=True)
|
||||
class QueryTransformer(Transformer):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.allowed_comparators = allowed_comparators
|
||||
self.allowed_operators = allowed_operators
|
||||
|
||||
def program(self, *items: Any) -> tuple:
|
||||
return items
|
||||
|
||||
def func_call(self, func_name: Any, *args: Any) -> FilterDirective:
|
||||
func = self._match_func_name(str(func_name))
|
||||
if isinstance(func, Comparator):
|
||||
return Comparison(comparator=func, attribute=args[0][0], value=args[0][1])
|
||||
return Operation(operator=func, arguments=args[0])
|
||||
|
||||
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
|
||||
if func_name in set(Comparator):
|
||||
if self.allowed_comparators is not None:
|
||||
if func_name not in self.allowed_comparators:
|
||||
raise ValueError(
|
||||
f"Received disallowed comparator {func_name}. Allowed "
|
||||
f"comparators are {self.allowed_comparators}"
|
||||
)
|
||||
return Comparator(func_name)
|
||||
elif func_name in set(Operator):
|
||||
if self.allowed_operators is not None:
|
||||
if func_name not in self.allowed_operators:
|
||||
raise ValueError(
|
||||
f"Received disallowed operator {func_name}. Allowed operators"
|
||||
f" are {self.allowed_operators}"
|
||||
)
|
||||
return Operator(func_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Received unrecognized function {func_name}. Valid functions are "
|
||||
f"{list(Operator) + list(Comparator)}"
|
||||
)
|
||||
|
||||
def args(self, *items: Any) -> tuple:
|
||||
return items
|
||||
|
||||
def false(self) -> bool:
|
||||
return False
|
||||
|
||||
def true(self) -> bool:
|
||||
return True
|
||||
|
||||
def list(self, item: Any) -> list:
|
||||
if item is None:
|
||||
return []
|
||||
return list(item)
|
||||
|
||||
def int(self, item: Any) -> int:
|
||||
return int(item)
|
||||
|
||||
def float(self, item: Any) -> float:
|
||||
return float(item)
|
||||
|
||||
def string(self, item: Any) -> str:
|
||||
# Remove escaped quotes
|
||||
return str(item).strip("\"'")
|
||||
|
||||
|
||||
def get_parser(
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
) -> Lark:
|
||||
transformer = QueryTransformer(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user