Mirror of https://github.com/hwchase17/langchain.git (synced 2026-04-13 07:52:48 +00:00)

Compare commits: langchain-… ... vwp/tools_ (197 commits)
Commit SHA1s:

4d5ce154f7, 67b5bb53e3, 50f6895900, 18138c6fc1, fc402b5d61, f078943a4c, fede5f75b0, 1b48ea8d73, eb9de308c5, 9291f4a53b, cd7d0eadee, 7ee24f4a15, 6b05934ddf, 0e81e83466, baa26ab294, 7af99e1abc, 225963a85e, 2e5b9389de, 48c7f95add, 1848a2ca93, 76ba417413, 1a6603ba1a, 0fcaf65152, 0f9e59e093, 46e43e5f67, 3b14bbd8f1, a8431deb32, 64891af07f, d8dcd39f87, e8fbb9dc8e, 5da0567f5f, f7639ba150, 45578e82c4, e83e178252, 40ebacb3ad, 73ca8e9164, 84a020c572, c971fefaab, e5f8878a82, 588bab73fa, c58f4c504e, a7cf0b247f, f057c5b118, 6104cf4c01, 453e4d2ce8, f621c3cb48, 81e21ba4fd, 48aa248e5c, ec060418d6, a269024194, 59af5c5d28, 43343fdfc8, 2a28bb0089, bac368b1aa, a1a795044a, 6bbe43e7c7, 54d3cdcf61, 695901c5d8, b764817c05, d371904151, 337b55a1b9, 334ddf0cb3, 7455df7820, f3932cdf41, b7999329c4, 7e6fb35238, 74fd5a98e0, 9bc8df63dc, 39806c8953, 29958c0b37, d644a34f6c, 10a7946c8e, 47d30c88f7, 5a2c53b978, f47a9e4d1f, 1f673e703a, ac9bcf4886, bbc43cf842, b4c94f05a8, 049181e883, b693a6a9a2, 3ffa814638, 7f0edd8353, e607ba5df8, 482348e852, 1edbbedf0a, 87c2564624, a8e81f8e0e, 7877088ead, 24bd1c4964, 91032df759, cc56ed88aa, 46674f4a6b, ea50abfffe, a47b127907, 3c4fdf5285, 24dd0fcbb0, d7ac46f62b, 794c342457, 2966b50608, 6bfd97b24d, e69208f362, 22ad33a8e1, e44317878b, 2402bb5b57, a61aa37010, 6cefec4c65, 5c8d6fa791, 4917c71695, ca98b3e519, fae3eb7223, 6544d2bc6f, 75c097dbdf, 6214b15d53, a21fc19d91, 385d9271eb, 0ed4b1050f, 312cb3fd88, 38a958eb30, 0471854072, 1b3bf86486, 6100ad65b1, ae6bda90fc, f44e275e1e, a83c4a7711, 17cbc6a5dd, 95fbd29353, 14a2599bd2, f17c7bbe83, b253b0b0d9, a588e5a311, b23e1de43b, b72d9c9d77, 4a08ffc2e0, 31ee671894, e5f184c7ba, 6d07bafda5, eb47767e9e, 9f40c09c86, b7dad1b6bf, e808444b79, f400386865, f3ab7c2a9f, 4b071a69d1, 0a7ca1014f, 326c2c2474, 8feb416664, 4003a79b35, 994027771e, fcf610bf31, a4a4c59e97, 3d762f2b8b, 69d60041af, 38c3478be4, c0fd62cd6f, 471f961c8e, a0f76dd40b, 3492a93dc4, bc76e0613c, 692baa797d, 15535b913d, 6d55489419, be958d98ad, 0f7d997bb0, 87c046858b, eecd5795f4, 181840dcb4, d2b2772272, 353e96cb66, a8b1bb6c4c, c08f644d6d, e12dc1321a, c31cdc6d8d, 7f4eb81be7, ba167800dd, 96f33ed3ae, 0bfa2c9216, 1a6e8865bf, fd948bef64, bd9d9412b7, 2ed4649e50, da27d8713d, 5dcb44ee1d, 15c0fa5e7e, 1fc3941430, 8ae809af67, 6cd653deb4, e953d2cf93, 50668693d7, 6fec15b6fb, 7bcdc66b99, 4cdd19bd4e, 90cef7b53a, 675e27c136, fa4a4f2940, 55c7964e4e, 3cc2ce6ac9
@@ -61,7 +61,6 @@
 "from datetime import datetime\n",
 "\n",
 "from langchain.llms import OpenAI\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks import AimCallbackHandler, StdOutCallbackHandler"
 ]
 },
@@ -109,8 +108,8 @@
 " experiment_name=\"scenario 1: OpenAI LLM\",\n",
 ")\n",
 "\n",
-"manager = CallbackManager([StdOutCallbackHandler(), aim_callback])\n",
-"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
+"callbacks = [StdOutCallbackHandler(), aim_callback]\n",
+"llm = OpenAI(temperature=0, callbacks=callbacks)"
 ]
 },
 {
@@ -177,7 +176,7 @@
 "Title: {title}\n",
 "Playwright: This is a synopsis for the above play:\"\"\"\n",
 "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
-"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
+"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
 "\n",
 "test_prompts = [\n",
 " {\"title\": \"documentary about good video games that push the boundary of game design\"},\n",
@@ -249,13 +248,12 @@
 ],
 "source": [
 "# scenario 3 - Agent with Tools\n",
-"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
+"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
 "agent = initialize_agent(\n",
 " tools,\n",
 " llm,\n",
 " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
-" callback_manager=manager,\n",
 " verbose=True,\n",
+" callbacks=callbacks,\n",
 ")\n",
 "agent.run(\n",
 " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
@@ -79,7 +79,6 @@
 "source": [
 "from datetime import datetime\n",
 "from langchain.callbacks import ClearMLCallbackHandler, StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.llms import OpenAI\n",
 "\n",
 "# Setup and use the ClearML Callback\n",
@@ -93,9 +92,9 @@
 " complexity_metrics=True,\n",
 " stream_logs=True\n",
 ")\n",
-"manager = CallbackManager([StdOutCallbackHandler(), clearml_callback])\n",
+"callbacks = [StdOutCallbackHandler(), clearml_callback]\n",
 "# Get the OpenAI model ready to go\n",
-"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
+"llm = OpenAI(temperature=0, callbacks=callbacks)"
 ]
 },
 {
@@ -523,13 +522,12 @@
 "from langchain.agents import AgentType\n",
 "\n",
 "# SCENARIO 2 - Agent with Tools\n",
-"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
+"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
 "agent = initialize_agent(\n",
 " tools,\n",
 " llm,\n",
 " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
-" callback_manager=manager,\n",
 " verbose=True,\n",
+" callbacks=callbacks,\n",
 ")\n",
 "agent.run(\n",
 " \"Who is the wife of the person who sang summer of 69?\"\n",
@@ -121,7 +121,6 @@
 "from datetime import datetime\n",
 "\n",
 "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.llms import OpenAI\n",
 "\n",
 "comet_callback = CometCallbackHandler(\n",
@@ -131,8 +130,8 @@
 " tags=[\"llm\"],\n",
 " visualizations=[\"dep\"],\n",
 ")\n",
-"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
-"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
+"llm = OpenAI(temperature=0.9, callbacks=callbacks, verbose=True)\n",
 "\n",
 "llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\", \"Tell me a fact\"] * 3)\n",
 "print(\"LLM result\", llm_result)\n",
@@ -153,7 +152,6 @@
 "outputs": [],
 "source": [
 "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.chains import LLMChain\n",
 "from langchain.llms import OpenAI\n",
 "from langchain.prompts import PromptTemplate\n",
@@ -164,15 +162,14 @@
 " stream_logs=True,\n",
 " tags=[\"synopsis-chain\"],\n",
 ")\n",
-"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
-"\n",
-"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
+"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
 "\n",
 "template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
 "Title: {title}\n",
 "Playwright: This is a synopsis for the above play:\"\"\"\n",
 "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
-"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
+"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
 "\n",
 "test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n",
 "print(synopsis_chain.apply(test_prompts))\n",
@@ -194,7 +191,6 @@
 "source": [
 "from langchain.agents import initialize_agent, load_tools\n",
 "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.llms import OpenAI\n",
 "\n",
 "comet_callback = CometCallbackHandler(\n",
@@ -203,15 +199,15 @@
 " stream_logs=True,\n",
 " tags=[\"agent\"],\n",
 ")\n",
-"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
-"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
+"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
 "\n",
-"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
+"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
 "agent = initialize_agent(\n",
 " tools,\n",
 " llm,\n",
 " agent=\"zero-shot-react-description\",\n",
-" callback_manager=manager,\n",
+" callbacks=callbacks,\n",
 " verbose=True,\n",
 ")\n",
 "agent.run(\n",
@@ -255,7 +251,6 @@
 "from rouge_score import rouge_scorer\n",
 "\n",
 "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.chains import LLMChain\n",
 "from langchain.llms import OpenAI\n",
 "from langchain.prompts import PromptTemplate\n",
@@ -298,10 +293,10 @@
 " tags=[\"custom_metrics\"],\n",
 " custom_metrics=rouge_score.compute_metric,\n",
 ")\n",
-"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
-"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
+"llm = OpenAI(temperature=0.9)\n",
 "\n",
-"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
+"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)\n",
 "\n",
 "test_prompts = [\n",
 " {\n",
@@ -323,7 +318,7 @@
 " \"\"\"\n",
 " }\n",
 "]\n",
-"print(synopsis_chain.apply(test_prompts))\n",
+"print(synopsis_chain.apply(test_prompts, callbacks=callbacks))\n",
 "comet_callback.flush_tracker(synopsis_chain, finish=True)"
 ]
 }
@@ -3,6 +3,7 @@
 This page covers how to use the `GPT4All` wrapper within LangChain. The tutorial is divided into two parts: installation and setup, followed by usage with an example.
 
 ## Installation and Setup
 
 - Install the Python package with `pip install pyllamacpp`
 - Download a [GPT4All model](https://github.com/nomic-ai/pyllamacpp#supported-model) and place it in your desired directory
 
@@ -28,16 +29,16 @@ To stream the model's predictions, add in a CallbackManager.
 
 ```python
 from langchain.llms import GPT4All
-from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 # There are many CallbackHandlers supported, such as
 # from langchain.callbacks.streamlit import StreamlitCallbackHandler
 
-callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True)
+callbacks = [StreamingStdOutCallbackHandler()]
+model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
 
 # Generate text. Tokens are streamed through the callback manager.
-model("Once upon a time, ")
+model("Once upon a time, ", callbacks=callbacks)
 ```
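The same migration in one self-contained snippet — a minimal sketch that assumes the illustrative model path above exists locally:

```python
from langchain.llms import GPT4All
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Handlers are now a plain list passed per call, not wired into a CallbackManager.
callbacks = [StreamingStdOutCallbackHandler()]
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)

# Tokens stream to stdout through the handler as they are generated.
model("Once upon a time, ", callbacks=callbacks)
```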

## Model File
@@ -50,7 +50,6 @@
 "source": [
 "from datetime import datetime\n",
 "from langchain.callbacks import WandbCallbackHandler, StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.llms import OpenAI"
 ]
 },
@@ -196,8 +195,8 @@
 " name=\"llm\",\n",
 " tags=[\"test\"],\n",
 ")\n",
-"manager = CallbackManager([StdOutCallbackHandler(), wandb_callback])\n",
-"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
+"callbacks = [StdOutCallbackHandler(), wandb_callback]\n",
+"llm = OpenAI(temperature=0, callbacks=callbacks)"
 ]
 },
 {
@@ -484,7 +483,7 @@
 "Title: {title}\n",
 "Playwright: This is a synopsis for the above play:\"\"\"\n",
 "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
-"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
+"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
 "\n",
 "test_prompts = [\n",
 " {\n",
@@ -577,16 +576,15 @@
 ],
 "source": [
 "# SCENARIO 3 - Agent with Tools\n",
-"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
+"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
 "agent = initialize_agent(\n",
 " tools,\n",
 " llm,\n",
 " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
-" callback_manager=manager,\n",
 " verbose=True,\n",
 ")\n",
 "agent.run(\n",
-" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
+" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\",\n",
+" callbacks=callbacks,\n",
 ")\n",
 "wandb_callback.flush_tracker(agent, reset=False, finish=True)"
 ]
@@ -42,7 +42,6 @@
 "from langchain.agents import AgentType\n",
 "from langchain.llms import OpenAI\n",
 "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks.tracers import LangChainTracer\n",
 "from aiohttp import ClientSession\n",
 "\n",
@@ -307,14 +306,14 @@
 " # To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
 " # but you must manually close the client session at the end of your program/event loop\n",
 " aiosession = ClientSession()\n",
+" callbacks = [StdOutCallbackHandler()]\n",
 " for _ in questions:\n",
-" manager = CallbackManager([StdOutCallbackHandler()])\n",
-" llm = OpenAI(temperature=0, callback_manager=manager)\n",
-" async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n",
+" llm = OpenAI(temperature=0)\n",
+" async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n",
 " agents.append(\n",
-" initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
+" initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n",
 " )\n",
-" tasks = [async_agent.arun(q) for async_agent, q in zip(agents, questions)]\n",
+" tasks = [async_agent.arun(q, callbacks=callbacks) for async_agent, q in zip(agents, questions)]\n",
 " await asyncio.gather(*tasks)\n",
 " await aiosession.close()\n",
 "\n",
@@ -376,14 +375,14 @@
 "aiosession = ClientSession()\n",
 "tracer = LangChainTracer()\n",
 "tracer.load_default_session()\n",
-"manager = CallbackManager([StdOutCallbackHandler(), tracer])\n",
+"callbacks = [StdOutCallbackHandler(), tracer]\n",
 "\n",
 "# Pass the manager into the llm if you want llm calls traced.\n",
-"llm = OpenAI(temperature=0, callback_manager=manager)\n",
+"llm = OpenAI(temperature=0)\n",
 "\n",
 "async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n",
-"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
-"await async_agent.arun(questions[0])\n",
+"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n",
+"await async_agent.arun(questions[0], callbacks=callbacks)\n",
 "await aiosession.close()"
 ]
 }
@@ -373,6 +373,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
 "tools = get_tools(\"whats the weather?\")\n",
 "tool_names = [tool.name for tool in tools]\n",
 "agent = LLMSingleActionAgent(\n",
 " llm_chain=llm_chain, \n",
File diff suppressed because it is too large.

docs/modules/chains/generic/custom_chain.ipynb (new file, 188 lines)
@@ -0,0 +1,188 @@
{
 "cells": [
 {
 "attachments": {},
 "cell_type": "markdown",
 "id": "593f7553-7038-498e-96d4-8255e5ce34f0",
 "metadata": {},
 "source": [
 "# Creating a custom Chain\n",
 "\n",
 "To implement your own custom chain you can subclass `Chain` and implement the following methods:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 11,
 "id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
 "metadata": {
 "tags": [],
 "vscode": {
 "languageId": "python"
 }
 },
 "outputs": [],
 "source": [
 "from __future__ import annotations\n",
 "\n",
 "from typing import Any, Dict, List, Optional\n",
 "\n",
 "from pydantic import Extra\n",
 "\n",
 "from langchain.base_language import BaseLanguageModel\n",
 "from langchain.callbacks.manager import (\n",
 "    AsyncCallbackManagerForChainRun,\n",
 "    CallbackManagerForChainRun,\n",
 ")\n",
 "from langchain.chains.base import Chain\n",
 "from langchain.prompts.base import BasePromptTemplate\n",
 "\n",
 "\n",
 "class MyCustomChain(Chain):\n",
 "    \"\"\"\n",
 "    An example of a custom chain.\n",
 "    \"\"\"\n",
 "\n",
 "    prompt: BasePromptTemplate\n",
 "    \"\"\"Prompt object to use.\"\"\"\n",
 "    llm: BaseLanguageModel\n",
 "    output_key: str = \"text\"  #: :meta private:\n",
 "\n",
 "    class Config:\n",
 "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
 "\n",
 "        extra = Extra.forbid\n",
 "        arbitrary_types_allowed = True\n",
 "\n",
 "    @property\n",
 "    def input_keys(self) -> List[str]:\n",
 "        \"\"\"Will be whatever keys the prompt expects.\n",
 "\n",
 "        :meta private:\n",
 "        \"\"\"\n",
 "        return self.prompt.input_variables\n",
 "\n",
 "    @property\n",
 "    def output_keys(self) -> List[str]:\n",
 "        \"\"\"Will always return text key.\n",
 "\n",
 "        :meta private:\n",
 "        \"\"\"\n",
 "        return [self.output_key]\n",
 "\n",
 "    def _call(\n",
 "        self,\n",
 "        inputs: Dict[str, Any],\n",
 "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
 "    ) -> Dict[str, str]:\n",
 "        # Your custom chain logic goes here\n",
 "        # This is just an example that mimics LLMChain\n",
 "        prompt_value = self.prompt.format_prompt(**inputs)\n",
 "\n",
 "        # Whenever you call a language model, or another chain, you should pass\n",
 "        # a callback manager to it. This allows the inner run to be tracked by\n",
 "        # any callbacks that are registered on the outer run.\n",
 "        # You can always obtain a callback manager for this by calling\n",
 "        # `run_manager.get_child()` as shown below.\n",
 "        response = self.llm.generate_prompt(\n",
 "            [prompt_value],\n",
 "            callbacks=run_manager.get_child() if run_manager else None\n",
 "        )\n",
 "\n",
 "        # If you want to log something about this run, you can do so by calling\n",
 "        # methods on the `run_manager`, as shown below. This will trigger any\n",
 "        # callbacks that are registered for that event.\n",
 "        if run_manager:\n",
 "            run_manager.on_text(\"Log something about this run\")\n",
 "\n",
 "        return {self.output_key: response.generations[0][0].text}\n",
 "\n",
 "    async def _acall(\n",
 "        self,\n",
 "        inputs: Dict[str, Any],\n",
 "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
 "    ) -> Dict[str, str]:\n",
 "        # Your custom chain logic goes here\n",
 "        # This is just an example that mimics LLMChain\n",
 "        prompt_value = self.prompt.format_prompt(**inputs)\n",
 "\n",
 "        # Whenever you call a language model, or another chain, you should pass\n",
 "        # a callback manager to it. This allows the inner run to be tracked by\n",
 "        # any callbacks that are registered on the outer run.\n",
 "        # You can always obtain a callback manager for this by calling\n",
 "        # `run_manager.get_child()` as shown below.\n",
 "        response = await self.llm.agenerate_prompt(\n",
 "            [prompt_value],\n",
 "            callbacks=run_manager.get_child() if run_manager else None\n",
 "        )\n",
 "\n",
 "        # If you want to log something about this run, you can do so by calling\n",
 "        # methods on the `run_manager`, as shown below. This will trigger any\n",
 "        # callbacks that are registered for that event.\n",
 "        if run_manager:\n",
 "            await run_manager.on_text(\"Log something about this run\")\n",
 "\n",
 "        return {self.output_key: response.generations[0][0].text}\n",
 "\n",
 "    @property\n",
 "    def _chain_type(self) -> str:\n",
 "        return \"my_custom_chain\"\n"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 12,
 "id": "18361f89",
 "metadata": {
 "vscode": {
 "languageId": "python"
 }
 },
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
 "\n",
 "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
 "Log something about this run\n",
 "\u001b[1m> Finished chain.\u001b[0m\n"
 ]
 },
 {
 "data": {
 "text/plain": [
 "'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
 ]
 },
 "execution_count": 12,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
 "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
 "from langchain.chat_models.openai import ChatOpenAI\n",
 "from langchain.prompts.prompt import PromptTemplate\n",
 "\n",
 "\n",
 "chain = MyCustomChain(\n",
 "    prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n",
 "    llm=ChatOpenAI()\n",
 ")\n",
 "\n",
 "chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])"
 ]
 }
 ],
 "metadata": {
 "kernelspec": {
 "display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
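The notebook only exercises the synchronous path. A minimal sketch of the async path, assuming the `MyCustomChain` definition above and that `arun` accepts `callbacks` the same way `run` does:

```python
import asyncio

from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate


async def main() -> None:
    chain = MyCustomChain(
        prompt=PromptTemplate.from_template("tell us a joke about {topic}"),
        llm=ChatOpenAI(),
    )
    # arun drives _acall, so run_manager.on_text fires through the async
    # callback machinery instead of the sync one.
    result = await chain.arun(
        {"topic": "callbacks"}, callbacks=[StdOutCallbackHandler()]
    )
    print(result)


asyncio.run(main())
```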
@@ -589,7 +589,6 @@
 "outputs": [],
 "source": [
 "from langchain.chains.llm import LLMChain\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
 "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
 "from langchain.chains.question_answering import load_qa_chain\n",
@@ -597,7 +596,7 @@
 "# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n",
 "# and a separate, non-streaming llm for question generation\n",
 "llm = OpenAI(temperature=0)\n",
-"streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"streaming_llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
 "\n",
 "question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
 "doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n",
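For context, these pieces are typically assembled roughly as below — a sketch, where the `vectorstore` retriever is an assumption not shown in the hunk:

```python
from langchain.chains import ConversationalRetrievalChain

# `vectorstore` is assumed to exist (e.g. a FAISS index built earlier).
qa = ConversationalRetrievalChain(
    retriever=vectorstore.as_retriever(),
    combine_docs_chain=doc_chain,            # streams tokens via streaming_llm
    question_generator=question_generator,   # non-streaming condense step
)
result = qa({"question": "What did the president say?", "chat_history": []})
```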
@@ -80,9 +80,8 @@
 }
 ],
 "source": [
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
-"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
 "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])"
 ]
 },
@@ -373,9 +373,8 @@
 }
 ],
 "source": [
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
-"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
 "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])\n"
 ]
 },
@@ -12,7 +12,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
 "metadata": {
 "tags": []
@@ -21,14 +21,13 @@
 "source": [
 "from langchain.llms import OpenAI, Anthropic\n",
 "from langchain.chat_models import ChatOpenAI\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
 "from langchain.schema import HumanMessage"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "77f60a4b-f786-41f2-972e-e5bb8a48dcd5",
 "metadata": {
 "tags": []
@@ -79,7 +78,7 @@
 }
 ],
 "source": [
-"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
 "resp = llm(\"Write me a song about sparkling water.\")"
 ]
 },
@@ -95,7 +94,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
 "id": "a35373f1-9ee6-4753-a343-5aee749b8527",
 "metadata": {
 "tags": []
@@ -136,7 +135,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "id": "22665f16-e05b-473c-a4bd-ad75744ea024",
 "metadata": {
 "tags": []
@@ -191,7 +190,7 @@
 }
 ],
 "source": [
-"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
 "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])"
 ]
 },
@@ -205,7 +204,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc",
 "metadata": {
 "tags": []
@@ -245,7 +244,7 @@
 }
 ],
 "source": [
-"llm = Anthropic(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"llm = Anthropic(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
 "llm(\"Write me a song about sparkling water.\")"
 ]
 }
@@ -40,7 +40,6 @@
 "source": [
 "from langchain import PromptTemplate, LLMChain\n",
 "from langchain.llms import GPT4All\n",
-"from langchain.callbacks.base import CallbackManager\n",
 "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
 ]
 },
@@ -124,9 +123,9 @@
 "outputs": [],
 "source": [
 "# Callbacks support token-wise streaming\n",
-"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
+"callbacks = [StreamingStdOutCallbackHandler()]\n",
 "# Verbose is required to pass to the callback manager\n",
-"llm = GPT4All(model=local_path, callback_manager=callback_manager, verbose=True)"
+"llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)"
 ]
 },
 {
@@ -6,16 +6,15 @@ First, you should install tracing and set up your environment properly.
 You can use either a locally hosted version of this (uses Docker) or a cloud hosted version (in closed alpha).
 If you're interested in using the hosted platform, please fill out the form [here](https://forms.gle/tRCEMSeopZf6TE3b6).
 
 - [Locally Hosted Setup](./tracing/local_installation.md)
 - [Cloud Hosted Setup](./tracing/hosted_installation.md)
 
 ## Tracing Walkthrough
 
 When you first access the UI, you should see a page with your tracing sessions.
 An initial one "default" should already be created for you.
 A session is just a way to group traces together.
 If you click on a session, it will take you to a page with no recorded traces that says "No Runs."
 You can create a new session with the new session form.
 
 
@@ -35,7 +34,7 @@ We can keep on clicking further and further down to explore deeper and deeper.
 
 
 
 We can also click on the "Explore" button of the top level run to dive even deeper.
 Here, we can see the inputs and outputs in full, as well as all the nested traces.
 
 
@@ -46,11 +45,12 @@ For example, here is the lowest level trace with the exact inputs/outputs to the
 
 
 ## Changing Sessions
 
 1. To initially record traces to a session other than `"default"`, you can set the `LANGCHAIN_SESSION` environment variable to the name of the session you want to record to:
 
 ```python
 import os
 os.environ["LANGCHAIN_HANDLER"] = "langchain"
+os.environ["LANGCHAIN_TRACING"] = "true"
 os.environ["LANGCHAIN_SESSION"] = "my_session"  # Make sure this session actually exists. You can create a new session in the UI.
 ```
@@ -5,11 +5,6 @@ from typing import Optional
 
 from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
 from langchain.cache import BaseCache
-from langchain.callbacks import (
-    set_default_callback_manager,
-    set_handler,
-    set_tracing_callback_manager,
-)
 from langchain.chains import (
     ConversationChain,
     LLMBashChain,
@@ -67,7 +62,6 @@ del metadata  # optional, avoids polluting the results of dir(__package__)
 
 verbose: bool = False
 llm_cache: Optional[BaseCache] = None
-set_default_callback_manager()
 
 # For backwards compatibility
 SerpAPIChain = SerpAPIWrapper
@@ -119,7 +113,5 @@ __all__ = [
     "VectorDBQAWithSourcesChain",
     "QAWithSourcesChain",
    "PALChain",
-    "set_handler",
-    "set_tracing_callback_manager",
     "LlamaCpp",
 ]
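Code that relied on the removed module-level helpers now has to attach handlers explicitly — a hedged migration sketch (the handler choice is illustrative):

```python
# Before (helpers removed in this diff):
#   from langchain.callbacks import set_handler
#   set_handler(my_handler)
#
# After: pass callbacks where objects are created or invoked.
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.llms import OpenAI

llm = OpenAI(temperature=0, callbacks=[StdOutCallbackHandler()])
llm("Hello!")
```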
@@ -13,7 +13,13 @@ import yaml
 from pydantic import BaseModel, root_validator
 
 from langchain.agents.tools import InvalidTool
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+    Callbacks,
+)
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
@@ -23,7 +29,6 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
     AgentAction,
     AgentFinish,
-    BaseLanguageModel,
     BaseMessage,
     BaseOutputParser,
 )
@@ -46,13 +51,17 @@ class BaseSingleActionAgent(BaseModel):
 
     @abstractmethod
     def plan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
@@ -61,13 +70,17 @@ class BaseSingleActionAgent(BaseModel):
 
     @abstractmethod
     async def aplan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
@@ -170,13 +183,17 @@ class BaseMultiActionAgent(BaseModel):
 
     @abstractmethod
     def plan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> Union[List[AgentAction], AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
@@ -185,13 +202,17 @@ class BaseMultiActionAgent(BaseModel):
 
     @abstractmethod
     async def aplan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
    ) -> Union[List[AgentAction], AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
@@ -285,38 +306,52 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
         return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
 
     def plan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
             Action specifying what tool to use.
         """
         output = self.llm_chain.run(
-            intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
+            intermediate_steps=intermediate_steps,
+            stop=self.stop,
+            callbacks=callbacks,
+            **kwargs,
         )
         return self.output_parser.parse(output)
 
     async def aplan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
             Action specifying what tool to use.
         """
         output = await self.llm_chain.arun(
-            intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
+            intermediate_steps=intermediate_steps,
+            stop=self.stop,
+            callbacks=callbacks,
+            **kwargs,
         )
         return self.output_parser.parse(output)
 
@@ -368,37 +403,45 @@ class Agent(BaseSingleActionAgent):
         return thoughts
 
     def plan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
             Action specifying what tool to use.
         """
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-        full_output = self.llm_chain.predict(**full_inputs)
+        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         return self.output_parser.parse(full_output)
 
     async def aplan(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
         """Given input, decided what to do.
 
         Args:
             intermediate_steps: Steps the LLM has taken to date,
                 along with observations
+            callbacks: Callbacks to run.
             **kwargs: User inputs.
 
         Returns:
             Action specifying what tool to use.
         """
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-        full_output = await self.llm_chain.apredict(**full_inputs)
+        full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
         return self.output_parser.parse(full_output)
 
     def get_full_inputs(
@@ -636,24 +679,27 @@ class AgentExecutor(Chain):
 
         return True
 
-    def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
-        self.callback_manager.on_agent_finish(
-            output, color="green", verbose=self.verbose
-        )
+    def _return(
+        self,
+        output: AgentFinish,
+        intermediate_steps: list,
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        if run_manager:
+            run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
         final_output = output.return_values
         if self.return_intermediate_steps:
             final_output["intermediate_steps"] = intermediate_steps
         return final_output
 
     async def _areturn(
-        self, output: AgentFinish, intermediate_steps: list
+        self,
+        output: AgentFinish,
+        intermediate_steps: list,
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_agent_finish(
-                output, color="green", verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_agent_finish(
+        if run_manager:
+            await run_manager.on_agent_finish(
                 output, color="green", verbose=self.verbose
             )
         final_output = output.return_values
@@ -667,13 +713,18 @@ class AgentExecutor(Chain):
         color_mapping: Dict[str, str],
         inputs: Dict[str, str],
         intermediate_steps: List[Tuple[AgentAction, str]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
         """Take a single step in the thought-action-observation loop.
 
         Override this to take control of how the agent makes and acts on choices.
         """
         # Call the LLM to see what to do.
-        output = self.agent.plan(intermediate_steps, **inputs)
+        output = self.agent.plan(
+            intermediate_steps,
+            callbacks=run_manager.get_child() if run_manager else None,
+            **inputs,
+        )
         # If the tool chosen is the finishing tool, then we end and return.
         if isinstance(output, AgentFinish):
             return output
@@ -684,9 +735,8 @@ class AgentExecutor(Chain):
         actions = output
         result = []
         for agent_action in actions:
-            self.callback_manager.on_agent_action(
-                agent_action, verbose=self.verbose, color="green"
-            )
+            if run_manager:
+                run_manager.on_agent_action(agent_action, color="green")
             # Otherwise we lookup the tool
             if agent_action.tool in name_to_tool_map:
                 tool = name_to_tool_map[agent_action.tool]
@@ -700,6 +750,7 @@ class AgentExecutor(Chain):
                     agent_action.tool_input,
                     verbose=self.verbose,
                     color=color,
+                    callbacks=run_manager.get_child() if run_manager else None,
                     **tool_run_kwargs,
                 )
             else:
@@ -708,6 +759,7 @@ class AgentExecutor(Chain):
                     agent_action.tool,
                     verbose=self.verbose,
                     color=None,
+                    callbacks=run_manager.get_child() if run_manager else None,
                     **tool_run_kwargs,
                 )
             result.append((agent_action, observation))
@@ -719,13 +771,18 @@ class AgentExecutor(Chain):
         color_mapping: Dict[str, str],
         inputs: Dict[str, str],
         intermediate_steps: List[Tuple[AgentAction, str]],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
         """Take a single step in the thought-action-observation loop.
 
         Override this to take control of how the agent makes and acts on choices.
         """
         # Call the LLM to see what to do.
-        output = await self.agent.aplan(intermediate_steps, **inputs)
+        output = await self.agent.aplan(
+            intermediate_steps,
+            callbacks=run_manager.get_child() if run_manager else None,
+            **inputs,
+        )
         # If the tool chosen is the finishing tool, then we end and return.
         if isinstance(output, AgentFinish):
             return output
@@ -738,12 +795,8 @@ class AgentExecutor(Chain):
         async def _aperform_agent_action(
             agent_action: AgentAction,
         ) -> Tuple[AgentAction, str]:
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_agent_action(
-                    agent_action, verbose=self.verbose, color="green"
-                )
-            else:
-                self.callback_manager.on_agent_action(
+            if run_manager:
+                await run_manager.on_agent_action(
                     agent_action, verbose=self.verbose, color="green"
                 )
             # Otherwise we lookup the tool
@@ -759,6 +812,7 @@ class AgentExecutor(Chain):
                     agent_action.tool_input,
                     verbose=self.verbose,
                     color=color,
+                    callbacks=run_manager.get_child() if run_manager else None,
                     **tool_run_kwargs,
                 )
             else:
@@ -767,6 +821,7 @@ class AgentExecutor(Chain):
                     agent_action.tool,
                     verbose=self.verbose,
                     color=None,
+                    callbacks=run_manager.get_child() if run_manager else None,
                     **tool_run_kwargs,
                 )
             return agent_action, observation
@@ -778,7 +833,11 @@ class AgentExecutor(Chain):
 
         return list(result)
 
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run text through and get agent response."""
         # Construct a mapping of tool name to tool for easy lookup
         name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -794,10 +853,16 @@ class AgentExecutor(Chain):
         # We now enter the agent loop (until it returns something).
         while self._should_continue(iterations, time_elapsed):
             next_step_output = self._take_next_step(
-                name_to_tool_map, color_mapping, inputs, intermediate_steps
+                name_to_tool_map,
+                color_mapping,
+                inputs,
+                intermediate_steps,
+                run_manager=run_manager,
             )
             if isinstance(next_step_output, AgentFinish):
-                return self._return(next_step_output, intermediate_steps)
+                return self._return(
+                    next_step_output, intermediate_steps, run_manager=run_manager
+                )
 
             intermediate_steps.extend(next_step_output)
             if len(next_step_output) == 1:
@@ -813,7 +878,11 @@ class AgentExecutor(Chain):
         )
         return self._return(output, intermediate_steps)
 
-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
         """Run text through and get agent response."""
         # Construct a mapping of tool name to tool for easy lookup
         name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -831,7 +900,11 @@ class AgentExecutor(Chain):
         try:
             while self._should_continue(iterations, time_elapsed):
                 next_step_output = await self._atake_next_step(
-                    name_to_tool_map, color_mapping, inputs, intermediate_steps
+                    name_to_tool_map,
+                    color_mapping,
+                    inputs,
+                    intermediate_steps,
+                    run_manager=run_manager,
                 )
                 if isinstance(next_step_output, AgentFinish):
                     return await self._areturn(next_step_output, intermediate_steps)
@@ -849,7 +922,9 @@ class AgentExecutor(Chain):
                 output = self.agent.return_stopped_response(
                     self.early_stopping_method, intermediate_steps, **inputs
                 )
-                return await self._areturn(output, intermediate_steps)
+                return await self._areturn(
+                    output, intermediate_steps, run_manager=run_manager
+                )
             except TimeoutError:
                 # stop early when interrupted by the async timeout
                 output = self.agent.return_stopped_response(
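To make the new wiring concrete, a sketch (not part of the diff) of how a caller threads callbacks into the reworked `plan` contract; `agent` and the `input` key are illustrative:

```python
from langchain.callbacks.stdout import StdOutCallbackHandler

callbacks = [StdOutCallbackHandler()]
# plan() now takes callbacks explicitly and forwards them to the inner
# LLMChain, instead of reading a globally configured CallbackManager.
action = agent.plan(
    intermediate_steps=[],   # no (AgentAction, observation) pairs yet
    callbacks=callbacks,
    input="What is 2 + 2?",  # **kwargs: the user inputs
)
```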
@@ -54,7 +54,7 @@ class FileManagementToolkit(BaseToolkit):
         tools: List[BaseTool] = []
         for tool in allowed_tools:
             tool_cls = _FILE_TOOLS[tool]
-            tools.append(tool_cls(root_dir=self.root_dir))
+            tools.append(tool_cls(root_dir=self.root_dir))  # type: ignore
         return tools
@@ -28,13 +28,13 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
 from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
 from langchain.agents.mrkl.base import ZeroShotAgent
 from langchain.agents.tools import Tool
+from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from langchain.llms.openai import OpenAI
 from langchain.memory import ReadOnlySharedMemory
 from langchain.prompts import PromptTemplate
 from langchain.prompts.base import BasePromptTemplate
 from langchain.requests import RequestsWrapper
-from langchain.schema import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.tools.requests.tool import BaseRequestsTool
@@ -4,10 +4,10 @@ from typing import List, Optional
 from pydantic import Field
 
 from langchain.agents.agent_toolkits.base import BaseToolkit
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.tools import BaseTool
 from langchain.tools.powerbi.prompt import QUESTION_TO_QUERY
 from langchain.tools.powerbi.tool import (
@@ -5,6 +5,7 @@ from pydantic import Field
 from langchain.agents.agent import Agent, AgentOutputParser
 from langchain.agents.chat.output_parser import ChatOutputParser
 from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
 from langchain.prompts.base import BasePromptTemplate
@@ -13,7 +14,7 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import AgentAction, BaseLanguageModel
+from langchain.schema import AgentAction
 from langchain.tools import BaseTool
@@ -9,10 +9,10 @@ from langchain.agents.agent import Agent, AgentOutputParser
 from langchain.agents.agent_types import AgentType
 from langchain.agents.conversational.output_parser import ConvoOutputParser
 from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.tools.base import BaseTool
@@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
     SUFFIX,
     TEMPLATE_TOOL_RESPONSE,
 )
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains import LLMChain
 from langchain.prompts.base import BasePromptTemplate
@@ -24,7 +25,6 @@ from langchain.prompts.chat import (
 from langchain.schema import (
     AgentAction,
     AIMessage,
-    BaseLanguageModel,
     BaseMessage,
     BaseOutputParser,
     HumanMessage,
@@ -4,8 +4,8 @@ from typing import Any, Optional, Sequence
 from langchain.agents.agent import AgentExecutor
 from langchain.agents.agent_types import AgentType
 from langchain.agents.loading import AGENT_TO_CLASS, load_agent
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
-from langchain.schema import BaseLanguageModel
 from langchain.tools.base import BaseTool
@@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
     return Tool(
         name="Calculator",
         description="Useful for when you need to answer questions about math.",
-        func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run,
-        coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun,
+        func=LLMMathChain(llm=llm).run,
+        coroutine=LLMMathChain(llm=llm).arun,
     )
@@ -10,10 +10,10 @@ from langchain.agents.agent_types import AgentType
 from langchain.agents.mrkl.output_parser import MRKLOutputParser
 from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.agents.tools import Tool
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.tools.base import BaseTool
@@ -1,9 +1,14 @@
|
||||
"""Interface for tools."""
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.tools.base import BaseTool, StructuredTool
|
||||
|
||||
|
||||
@@ -44,14 +49,44 @@ class Tool(BaseTool):
            )
        return tuple(all_args), {}

    def _run(self, *args: Any, **kwargs: Any) -> Any:
    def _run(
        self,
        *args: Any,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool."""
        return self.func(*args, **kwargs)
        new_argument_supported = signature(self.func).parameters.get("callbacks")
        return (
            self.func(
                *args,
                callbacks=run_manager.get_child() if run_manager else None,
                **kwargs,
            )
            if new_argument_supported
            else self.func(*args, **kwargs)
        )

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
    async def _arun(
        self,
        *args: Any,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously."""
        if self.coroutine:
            return await self.coroutine(*args, **kwargs)
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support async")

    # TODO: this is for backwards compatibility, remove in future

@@ -70,11 +105,17 @@ class InvalidTool(BaseTool):
    name = "invalid_tool"
    description = "Called when tool name is invalid."

    def _run(self, tool_name: str) -> str:
    def _run(
        self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool."""
        return f"{tool_name} is not a valid tool, try another one."

    async def _arun(self, tool_name: str) -> str:
    async def _arun(
        self,
        tool_name: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return f"{tool_name} is not a valid tool, try another one."

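With these changes, Tool._run inspects the wrapped function's signature and, if it declares a callbacks parameter, forwards a child callback manager derived from the current run. A minimal sketch of a tool function opting in; the search function and its body are hypothetical:

from langchain.agents import Tool
from langchain.callbacks.manager import Callbacks


def search(query: str, callbacks: Callbacks = None) -> str:
    # Because this function declares a `callbacks` parameter, Tool._run
    # passes run_manager.get_child() here, so anything invoked inside the
    # tool reports to the same trace as the outer run.
    return f"results for {query}"  # hypothetical implementation


tool = Tool(name="Search", description="Look things up.", func=search)
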
55 langchain/base_language.py Normal file

@@ -0,0 +1,55 @@
"""Base class for all language models."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel

from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string


class BaseLanguageModel(BaseModel, ABC):
    @abstractmethod
    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    @abstractmethod
    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        # TODO: this method may not be exact.
        # TODO: this method may differ based on model (eg codex).
        try:
            from transformers import GPT2TokenizerFast
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install transformers`."
            )
        # create a GPT-2 tokenizer instance
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

        # tokenize the text using the GPT-2 tokenizer
        tokenized_text = tokenizer.tokenize(text)

        # calculate the number of tokens in the tokenized text
        return len(tokenized_text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages."""
        return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])

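The new langchain/base_language.py gives every language model a shared token-counting surface. A short sketch of how the helpers compose, assuming an OpenAI key is configured and transformers is installed (subclasses may override the count with a model-specific tokenizer):

from langchain.llms import OpenAI
from langchain.schema import HumanMessage

llm = OpenAI()  # any BaseLanguageModel subclass works here

# Default implementation falls back to the GPT-2 tokenizer from
# `transformers`, so the count is approximate for other vocabularies.
n = llm.get_num_tokens("How many tokens is this?")

# Sums per-message counts after rendering each message to its buffer string.
m = llm.get_num_tokens_from_messages([HumanMessage(content="hello")])
print(n, m)
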
@@ -1,80 +1,27 @@
"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator, Optional
from typing import Generator

from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.base import (
    AsyncCallbackManager,
    BaseCallbackHandler,
    BaseCallbackManager,
    CallbackManager,
)
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.manager import (
    CallbackManager,
    get_openai_callback,
    tracing_enabled,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.tracers import LangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler


def get_callback_manager() -> BaseCallbackManager:
    """Return the shared callback manager."""
    return SharedCallbackManager()


def set_handler(handler: BaseCallbackHandler) -> None:
    """Set handler."""
    callback = get_callback_manager()
    callback.set_handler(handler)


def set_default_callback_manager() -> None:
    """Set default callback manager."""
    default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
    if default_handler == "stdout":
        set_handler(StdOutCallbackHandler())
    elif default_handler == "langchain":
        session = os.environ.get("LANGCHAIN_SESSION")
        set_tracing_callback_manager(session)
    else:
        raise ValueError(
            f"LANGCHAIN_HANDLER should be one of `stdout` "
            f"or `langchain`, got {default_handler}"
        )


def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
    """Set tracing callback manager."""
    handler = SharedLangChainTracer()
    callback = get_callback_manager()
    callback.set_handlers([handler, StdOutCallbackHandler()])
    if session_name is None:
        handler.load_default_session()
    else:
        try:
            handler.load_session(session_name)
        except Exception:
            raise ValueError(f"session {session_name} not found")


@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Get OpenAI callback handler in a context manager."""
    handler = OpenAICallbackHandler()
    manager = get_callback_manager()
    manager.add_handler(handler)
    yield handler
    manager.remove_handler(handler)


__all__ = [
    "CallbackManager",
    "AsyncCallbackManager",
    "OpenAICallbackHandler",
    "SharedCallbackManager",
    "StdOutCallbackHandler",
    "AimCallbackHandler",
    "WandbCallbackHandler",
@@ -82,8 +29,5 @@ __all__ = [
    "CometCallbackHandler",
    "AsyncIteratorCallbackHandler",
    "get_openai_callback",
    "set_tracing_callback_manager",
    "set_default_callback_manager",
    "set_handler",
    "get_callback_manager",
    "tracing_enabled",
]

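The module-level singleton helpers (get_callback_manager, set_handler, set_default_callback_manager, set_tracing_callback_manager) disappear from the public surface; the scoped context managers re-exported from langchain.callbacks.manager replace them. A sketch of the import migration implied by the new __all__:

# Removed module-level singleton style:
# from langchain.callbacks import get_callback_manager, set_handler

# Replacement: scoped context managers re-exported from callbacks.manager.
from langchain.callbacks import get_openai_callback, tracing_enabled
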
@@ -1,19 +1,173 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
import asyncio
import functools
from abc import ABC, abstractmethod
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations

import copy
from typing import Any, Dict, List, Optional, Union

from langchain.schema import AgentAction, AgentFinish, LLMResult


class BaseCallbackHandler(ABC):
    """Base callback handler that can be used to handle callbacks from langchain."""
class LLMManagerMixin:
    """Mixin for LLM callbacks."""

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return False
    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM ends running."""

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM errors."""


class ChainManagerMixin:
    """Mixin for chain callbacks."""

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when chain ends running."""

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when chain errors."""

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent action."""

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent end."""


class ToolManagerMixin:
    """Mixin for tool callbacks."""

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when tool ends running."""

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when tool errors."""


class CallbackManagerMixin:
    """Mixin for callback manager."""

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM starts running."""

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when chain starts running."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when tool starts running."""


class RunManagerMixin:
    """Mixin for run manager."""

    def on_text(
        self,
        text: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on arbitrary text."""


class BaseCallbackHandler(
    LLMManagerMixin,
    ChainManagerMixin,
    ToolManagerMixin,
    CallbackManagerMixin,
    RunManagerMixin,
):
    """Base callback handler that can be used to handle callbacks from langchain."""

    @property
    def ignore_llm(self) -> bool:
@@ -30,480 +184,197 @@ class BaseCallbackHandler(ABC):
        """Whether to ignore agent callbacks."""
        return False

    @abstractmethod
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:

class AsyncCallbackHandler(BaseCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running."""

    @abstractmethod
    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
    async def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""

    @abstractmethod
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM ends running."""

    @abstractmethod
    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
    async def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM errors."""

    @abstractmethod
    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain starts running."""

    @abstractmethod
    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
    async def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain ends running."""

    @abstractmethod
    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
    async def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain errors."""

    @abstractmethod
    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool starts running."""

    @abstractmethod
    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool ends running."""

    @abstractmethod
    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
    async def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool errors."""

    @abstractmethod
    def on_text(self, text: str, **kwargs: Any) -> Any:
    async def on_text(
        self,
        text: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run on arbitrary text."""

    @abstractmethod
    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    async def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run on agent action."""

    @abstractmethod
    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run on agent end."""

class BaseCallbackManager(BaseCallbackHandler, ABC):
class BaseCallbackManager(CallbackManagerMixin):
    """Base callback manager that can be used to handle callbacks from LangChain."""

    def __init__(
        self,
        handlers: List[BaseCallbackHandler],
        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
        parent_run_id: Optional[str] = None,
    ) -> None:
        """Initialize callback manager."""
        self.handlers: List[BaseCallbackHandler] = handlers
        self.inheritable_handlers: List[BaseCallbackHandler] = (
            inheritable_handlers or []
        )
        self.parent_run_id: Optional[str] = parent_run_id

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async."""
        return False

    @abstractmethod
    def add_handler(self, callback: BaseCallbackHandler) -> None:
    def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)
        if inherit:
            self.inheritable_handlers.append(handler)

    @abstractmethod
    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)
        self.inheritable_handlers.remove(handler)

    def set_handler(self, handler: BaseCallbackHandler) -> None:
    def set_handlers(
        self, handlers: List[BaseCallbackHandler], inherit: bool = True
    ) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = []
        self.inheritable_handlers = []
        for handler in handlers:
            self.add_handler(handler, inherit=inherit)

    def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Set handler as the only handler on the callback manager."""
        self.set_handlers([handler])
        self.set_handlers([handler], inherit=inherit)

    @abstractmethod
    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""

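BaseCallbackManager is no longer abstract: it now owns two handler lists and an inherit flag decides whether a handler also propagates to child run managers. A minimal sketch of the new bookkeeping:

from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

manager = BaseCallbackManager(handlers=[])
manager.add_handler(StdOutCallbackHandler(), inherit=True)   # visible to children
manager.add_handler(StdOutCallbackHandler(), inherit=False)  # this manager only

assert len(manager.handlers) == 2
assert len(manager.inheritable_handlers) == 1
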
class CallbackManager(BaseCallbackManager):
    """Callback manager that can be used to handle callbacks from langchain."""

    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
        """Initialize callback manager."""
        self.handlers: List[BaseCallbackHandler] = handlers

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
                    handler.on_llm_start(serialized, prompts, **kwargs)

    def on_llm_new_token(
        self, token: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM generates a new token."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
                    handler.on_llm_new_token(token, **kwargs)

    def on_llm_end(
        self, response: LLMResult, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM ends running."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
                    handler.on_llm_end(response)

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
                    handler.on_llm_error(error)

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        for handler in self.handlers:
            if not handler.ignore_chain:
                if verbose or handler.always_verbose:
                    handler.on_chain_start(serialized, inputs, **kwargs)

    def on_chain_end(
        self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when chain ends running."""
        for handler in self.handlers:
            if not handler.ignore_chain:
                if verbose or handler.always_verbose:
                    handler.on_chain_end(outputs)

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        for handler in self.handlers:
            if not handler.ignore_chain:
                if verbose or handler.always_verbose:
                    handler.on_chain_error(error)

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        for handler in self.handlers:
            if not handler.ignore_agent:
                if verbose or handler.always_verbose:
                    handler.on_tool_start(serialized, input_str, **kwargs)

    def on_agent_action(
        self, action: AgentAction, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on agent action."""
        for handler in self.handlers:
            if not handler.ignore_agent:
                if verbose or handler.always_verbose:
                    handler.on_agent_action(action, **kwargs)

    def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
        """Run when tool ends running."""
        for handler in self.handlers:
            if not handler.ignore_agent:
                if verbose or handler.always_verbose:
                    handler.on_tool_end(output, **kwargs)

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        for handler in self.handlers:
            if not handler.ignore_agent:
                if verbose or handler.always_verbose:
                    handler.on_tool_error(error)

    def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
        """Run on additional input from chains and agents."""
        for handler in self.handlers:
            if verbose or handler.always_verbose:
                handler.on_text(text, **kwargs)

    def on_agent_finish(
        self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        for handler in self.handlers:
            if not handler.ignore_agent:
                if verbose or handler.always_verbose:
                    handler.on_agent_finish(finish, **kwargs)

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers

class AsyncCallbackHandler(BaseCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""

    async def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""

    async def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""

    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""

    async def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when chain errors."""

    async def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""

    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""

    async def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when tool errors."""

    async def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""

    async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
        """Run on agent action."""

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""

async def _handle_event_for_handler(
    handler: BaseCallbackHandler,
    event_name: str,
    ignore_condition_name: Optional[str],
    verbose: bool,
    *args: Any,
    **kwargs: Any
) -> None:
    if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
        if verbose or handler.always_verbose:
            event = getattr(handler, event_name)
            if asyncio.iscoroutinefunction(event):
                await event(*args, **kwargs)
            else:
                await asyncio.get_event_loop().run_in_executor(
                    None, functools.partial(event, *args, **kwargs)
                )


class AsyncCallbackManager(BaseCallbackManager):
    """Async callback manager that can be used to handle callbacks from LangChain."""

    @property
    def is_async(self) -> bool:
        """Return whether the handler is async."""
        return True

    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
        """Initialize callback manager."""
        self.handlers: List[BaseCallbackHandler] = handlers

    async def _handle_event(
        self,
        event_name: str,
        ignore_condition_name: Optional[str],
        verbose: bool,
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Generic event handler for AsyncCallbackManager."""
        await asyncio.gather(
            *(
                _handle_event_for_handler(
                    handler, event_name, ignore_condition_name, verbose, *args, **kwargs
                )
                for handler in self.handlers
            )
        )
    def __copy__(self) -> "BaseCallbackManager":
        return self.__class__(
            self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id
        )

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        await self._handle_event(
            "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
        )
    def __deepcopy__(self, memo: dict) -> "BaseCallbackManager":
        return self.__class__(
            [copy.deepcopy(handler, memo) for handler in self.handlers],
            [copy.deepcopy(handler, memo) for handler in self.inheritable_handlers],
            self.parent_run_id,
        )

    async def on_llm_new_token(
        self, token: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        await self._handle_event(
            "on_llm_new_token", "ignore_llm", verbose, token, **kwargs
        )

    async def on_llm_end(
        self, response: LLMResult, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM ends running."""
        await self._handle_event(
            "on_llm_end", "ignore_llm", verbose, response, **kwargs
        )

    async def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        await self._handle_event(
            "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
        )

    async def on_chain_end(
        self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when chain ends running."""
        await self._handle_event(
            "on_chain_end", "ignore_chain", verbose, outputs, **kwargs
        )

    async def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        await self._handle_event(
            "on_chain_error", "ignore_chain", verbose, error, **kwargs
        )

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        await self._handle_event(
            "on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
        )

    async def on_tool_end(
        self, output: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when tool ends running."""
        await self._handle_event(
            "on_tool_end", "ignore_agent", verbose, output, **kwargs
        )

    async def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        await self._handle_event(
            "on_tool_error", "ignore_agent", verbose, error, **kwargs
        )

    async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
        """Run when text is printed."""
        await self._handle_event("on_text", None, verbose, text, **kwargs)

    async def on_agent_action(
        self, action: AgentAction, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on agent action."""
        await self._handle_event(
            "on_agent_action", "ignore_agent", verbose, action, **kwargs
        )

    async def on_agent_finish(
        self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when agent finishes."""
        await self._handle_event(
            "on_agent_finish", "ignore_agent", verbose, finish, **kwargs
        )

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers

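Taken together, the handler side of the refactor splits the old monolithic abstract class into per-event mixins, so a custom handler only overrides what it needs, and every hook now receives run_id and parent_run_id. A minimal sketch of a handler under the new interface:

from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler


class PrintRunHandler(BaseCallbackHandler):
    """Overrides only the LLM start hook; other events fall back to no-op mixins."""

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        print(f"LLM run {run_id} (parent={parent_run_id}), {len(prompts)} prompts")
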
731 langchain/callbacks/manager.py Normal file

@@ -0,0 +1,731 @@
from __future__ import annotations

import asyncio
import copy
import functools
import os
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, TypeVar, Union

from langchain.callbacks.base import (
    BaseCallbackHandler,
    BaseCallbackManager,
    ChainManagerMixin,
    LLMManagerMixin,
    RunManagerMixin,
    ToolManagerMixin,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.schema import AgentAction, AgentFinish, LLMResult

Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback", default=None
)
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
    "tracing_callback", default=None
)


@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Get OpenAI callback handler in a context manager."""
    cb = OpenAICallbackHandler()
    openai_callback_var.set(cb)
    yield cb
    openai_callback_var.set(None)


@contextmanager
def tracing_enabled(
    session_name: str = "default",
) -> Generator[TracerSession, None, None]:
    """Get the LangChain tracer in a context manager."""
    cb = LangChainTracer()
    session = cb.load_session(session_name)
    tracing_callback_var.set(cb)
    yield session
    tracing_callback_var.set(None)

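Both context managers now stash their handler in a ContextVar instead of mutating a shared singleton, so _configure (below) can pick the handler up for any manager built while the with block is active. Usage stays the same as before; this sketch assumes an OpenAI key is configured:

from langchain.callbacks.manager import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI()  # assumes OPENAI_API_KEY is set
with get_openai_callback() as cb:
    llm("Tell me a joke")
    # cb accumulated token/cost info for every OpenAI call in this block
    print(cb.total_tokens)
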
def _handle_event(
    handlers: List[BaseCallbackHandler],
    event_name: str,
    ignore_condition_name: Optional[str],
    *args: Any,
    **kwargs: Any,
) -> None:
    for handler in handlers:
        try:
            if ignore_condition_name is None or not getattr(
                handler, ignore_condition_name
            ):
                getattr(handler, event_name)(*args, **kwargs)
        except Exception as e:
            # TODO: switch this to use logging
            print(f"Error in {event_name} callback: {e}")


async def _ahandle_event_for_handler(
    handler: BaseCallbackHandler,
    event_name: str,
    ignore_condition_name: Optional[str],
    *args: Any,
    **kwargs: Any,
) -> None:
    try:
        if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
            event = getattr(handler, event_name)
            if asyncio.iscoroutinefunction(event):
                await event(*args, **kwargs)
            else:
                await asyncio.get_event_loop().run_in_executor(
                    None, functools.partial(event, *args, **kwargs)
                )
    except Exception as e:
        # TODO: switch this to use logging
        print(f"Error in {event_name} callback: {e}")


async def _ahandle_event(
    handlers: List[BaseCallbackHandler],
    event_name: str,
    ignore_condition_name: Optional[str],
    *args: Any,
    **kwargs: Any,
) -> None:
    """Generic event handler for AsyncCallbackManager."""
    await asyncio.gather(
        *(
            _ahandle_event_for_handler(
                handler, event_name, ignore_condition_name, *args, **kwargs
            )
            for handler in handlers
        )
    )


BRM = TypeVar("BRM", bound="BaseRunManager")

class BaseRunManager(RunManagerMixin):
    """Base class for run manager (a bound callback manager)."""

    def __init__(
        self,
        run_id: str,
        handlers: List[BaseCallbackHandler],
        inheritable_handlers: List[BaseCallbackHandler],
        parent_run_id: Optional[str] = None,
    ) -> None:
        """Initialize run manager."""
        self.run_id = run_id
        self.handlers = handlers
        self.inheritable_handlers = inheritable_handlers
        self.parent_run_id = parent_run_id

    @classmethod
    def get_noop_manager(cls: Type[BRM]) -> BRM:
        """Return a manager that doesn't perform any operations."""
        return cls("", [], [])


class RunManager(BaseRunManager):
    """Sync Run Manager."""

    def on_text(
        self,
        text: str,
        **kwargs: Any,
    ) -> Any:
        """Run when text is received."""
        _handle_event(
            self.handlers,
            "on_text",
            None,
            text,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )


class AsyncRunManager(BaseRunManager):
    """Async Run Manager."""

    async def on_text(
        self,
        text: str,
        **kwargs: Any,
    ) -> Any:
        """Run when text is received."""
        await _ahandle_event(
            self.handlers,
            "on_text",
            None,
            text,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
    """Callback manager for LLM run."""

    def on_llm_new_token(
        self,
        token: str,
        **kwargs: Any,
    ) -> None:
        """Run when LLM generates a new token."""
        _handle_event(
            self.handlers,
            "on_llm_new_token",
            "ignore_llm",
            token=token,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        _handle_event(
            self.handlers,
            "on_llm_end",
            "ignore_llm",
            response,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when LLM errors."""
        _handle_event(
            self.handlers,
            "on_llm_error",
            "ignore_llm",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )


class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
    """Async callback manager for LLM run."""

    async def on_llm_new_token(
        self,
        token: str,
        **kwargs: Any,
    ) -> None:
        """Run when LLM generates a new token."""
        await _ahandle_event(
            self.handlers,
            "on_llm_new_token",
            "ignore_llm",
            token,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        await _ahandle_event(
            self.handlers,
            "on_llm_end",
            "ignore_llm",
            response,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when LLM errors."""
        await _ahandle_event(
            self.handlers,
            "on_llm_error",
            "ignore_llm",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
    """Callback manager for chain run."""

    def get_child(self) -> CallbackManager:
        """Get a child callback manager."""
        manager = CallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        _handle_event(
            self.handlers,
            "on_chain_end",
            "ignore_chain",
            outputs,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when chain errors."""
        _handle_event(
            self.handlers,
            "on_chain_error",
            "ignore_chain",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run when agent action is received."""
        _handle_event(
            self.handlers,
            "on_agent_action",
            "ignore_agent",
            action,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run when agent finish is received."""
        _handle_event(
            self.handlers,
            "on_agent_finish",
            "ignore_agent",
            finish,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )


class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
    """Async callback manager for chain run."""

    def get_child(self) -> AsyncCallbackManager:
        """Get a child callback manager."""
        manager = AsyncCallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        await _ahandle_event(
            self.handlers,
            "on_chain_end",
            "ignore_chain",
            outputs,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when chain errors."""
        await _ahandle_event(
            self.handlers,
            "on_chain_error",
            "ignore_chain",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run when agent action is received."""
        await _ahandle_event(
            self.handlers,
            "on_agent_action",
            "ignore_agent",
            action,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run when agent finish is received."""
        await _ahandle_event(
            self.handlers,
            "on_agent_finish",
            "ignore_agent",
            finish,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

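Each *ForChainRun manager hands nested components a child manager whose parent_run_id is the chain's run_id, which is how the tracer reassembles the run tree. A sketch of the pattern a chain body follows; the inner_chain object and the callbacks-forwarding call are illustrative, not taken verbatim from this diff:

from langchain.callbacks.manager import CallbackManagerForChainRun


def call_with_child(inner_chain, inputs, run_manager: CallbackManagerForChainRun):
    # The child manager carries only the inheritable handlers and points back
    # at this run via parent_run_id=run_manager.run_id.
    child = run_manager.get_child()
    return inner_chain.run(inputs, callbacks=child)  # hypothetical inner chain
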
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
    """Callback manager for tool run."""

    def get_child(self) -> CallbackManager:
        """Get a child callback manager."""
        manager = CallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    def on_tool_end(
        self,
        output: str,
        **kwargs: Any,
    ) -> None:
        """Run when tool ends running."""
        _handle_event(
            self.handlers,
            "on_tool_end",
            "ignore_agent",
            output,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when tool errors."""
        _handle_event(
            self.handlers,
            "on_tool_error",
            "ignore_agent",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )


class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
    """Async callback manager for tool run."""

    def get_child(self) -> AsyncCallbackManager:
        """Get a child callback manager."""
        manager = AsyncCallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        await _ahandle_event(
            self.handlers,
            "on_tool_end",
            "ignore_agent",
            output,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when tool errors."""
        await _ahandle_event(
            self.handlers,
            "on_tool_error",
            "ignore_agent",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

class CallbackManager(BaseCallbackManager):
    """Callback manager that can be used to handle callbacks from langchain."""

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallbackManagerForLLMRun:
        """Run when LLM starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        _handle_event(
            self.handlers,
            "on_llm_start",
            "ignore_llm",
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return CallbackManagerForLLMRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallbackManagerForChainRun:
        """Run when chain starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        _handle_event(
            self.handlers,
            "on_chain_start",
            "ignore_chain",
            serialized,
            inputs,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return CallbackManagerForChainRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        run_id: Optional[str] = None,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallbackManagerForToolRun:
        """Run when tool starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        _handle_event(
            self.handlers,
            "on_tool_start",
            "ignore_agent",
            serialized,
            input_str,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return CallbackManagerForToolRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    @classmethod
    def configure(
        cls,
        inheritable_callbacks: Callbacks = None,
        local_callbacks: Callbacks = None,
        verbose: bool = False,
    ) -> CallbackManager:
        """Configure the callback manager."""
        return _configure(cls, inheritable_callbacks, local_callbacks, verbose)

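Note that on_llm_start and friends now return a run manager instead of just firing events; the caller keeps it for the duration of the run and uses it for the matching end or error callback. A sketch, where the serialized payload {"name": "fake-llm"} is a stand-in:

from langchain.callbacks.manager import CallbackManager
from langchain.schema import LLMResult

manager = CallbackManager(handlers=[])
run = manager.on_llm_start({"name": "fake-llm"}, ["hello"])  # run_id auto-generated
# ... invoke the model here ...
run.on_llm_end(LLMResult(generations=[[]]))  # fires with the same run_id
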
class AsyncCallbackManager(BaseCallbackManager):
    """Async callback manager that can be used to handle callbacks from LangChain."""

    @property
    def is_async(self) -> bool:
        """Return whether the handler is async."""
        return True

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForLLMRun:
        """Run when LLM starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        await _ahandle_event(
            self.handlers,
            "on_llm_start",
            "ignore_llm",
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return AsyncCallbackManagerForLLMRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForChainRun:
        """Run when chain starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        await _ahandle_event(
            self.handlers,
            "on_chain_start",
            "ignore_chain",
            serialized,
            inputs,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return AsyncCallbackManagerForChainRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        run_id: Optional[str] = None,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForToolRun:
        """Run when tool starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        await _ahandle_event(
            self.handlers,
            "on_tool_start",
            "ignore_agent",
            serialized,
            input_str,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return AsyncCallbackManagerForToolRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    @classmethod
    def configure(
        cls,
        inheritable_callbacks: Callbacks = None,
        local_callbacks: Callbacks = None,
        verbose: bool = False,
    ) -> AsyncCallbackManager:
        """Configure the callback manager."""
        return _configure(cls, inheritable_callbacks, local_callbacks, verbose)

T = TypeVar("T", CallbackManager, AsyncCallbackManager)


def _configure(
    callback_manager_cls: Type[T],
    inheritable_callbacks: Callbacks = None,
    local_callbacks: Callbacks = None,
    verbose: bool = False,
) -> T:
    """Configure the callback manager."""
    callback_manager = callback_manager_cls([])
    if inheritable_callbacks or local_callbacks:
        if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
            inheritable_callbacks_ = inheritable_callbacks or []
            callback_manager = callback_manager_cls(
                handlers=inheritable_callbacks_,
                inheritable_handlers=inheritable_callbacks_,
            )
        else:
            callback_manager = callback_manager_cls(
                handlers=inheritable_callbacks.handlers,
                inheritable_handlers=inheritable_callbacks.inheritable_handlers,
                parent_run_id=inheritable_callbacks.parent_run_id,
            )
        callback_manager = copy.deepcopy(callback_manager)
        local_handlers_ = (
            local_callbacks
            if isinstance(local_callbacks, list)
            else (local_callbacks.handlers if local_callbacks else [])
        )
        for handler in local_handlers_:
            callback_manager.add_handler(copy.deepcopy(handler), False)

    tracer = tracing_callback_var.get()
    open_ai = openai_callback_var.get()
    tracing_enabled_ = (
        os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None
    )
    if verbose or tracing_enabled_ or open_ai is not None:
        if verbose and not any(
            isinstance(handler, StdOutCallbackHandler)
            for handler in callback_manager.handlers
        ):
            callback_manager.add_handler(StdOutCallbackHandler(), False)

        if tracing_enabled_ and not any(
            isinstance(handler, LangChainTracer)
            for handler in callback_manager.handlers
        ):
            if tracer:
                callback_manager.add_handler(copy.deepcopy(tracer), True)
            else:
                handler = LangChainTracer()
                handler.load_default_session()
                callback_manager.add_handler(handler, True)
        if open_ai is not None and not any(
            isinstance(handler, OpenAICallbackHandler)
            for handler in callback_manager.handlers
        ):
            callback_manager.add_handler(open_ai, True)

    return callback_manager

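_configure is the seam where everything meets: inheritable callbacks from the parent, local callbacks for this component, plus ambient handlers from the ContextVars and the LANGCHAIN_TRACING environment variable. A sketch of the two entry points the classmethods provide:

from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

local = [StdOutCallbackHandler()]
sync_mgr = CallbackManager.configure(inheritable_callbacks=None, local_callbacks=local)
# verbose=True adds a StdOutCallbackHandler unless one is already present.
async_mgr = AsyncCallbackManager.configure(local_callbacks=local, verbose=True)
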
@@ -1,127 +0,0 @@
"""A shared CallbackManager."""

import threading
from typing import Any, Dict, List, Union

from langchain.callbacks.base import (
    BaseCallbackHandler,
    BaseCallbackManager,
    CallbackManager,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult


class Singleton:
    """A thread-safe singleton class that can be inherited from."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls) -> Any:
        """Create a new shared instance of the class."""
        if cls._instance is None:
            with cls._lock:
                # Another thread could have created the instance
                # before we acquired the lock. So check that the
                # instance is still nonexistent.
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance


class SharedCallbackManager(Singleton, BaseCallbackManager):
    """A thread-safe singleton CallbackManager."""

    _callback_manager: CallbackManager = CallbackManager(handlers=[])

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with self._lock:
            self._callback_manager.on_llm_start(serialized, prompts, **kwargs)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        with self._lock:
            self._callback_manager.on_llm_end(response, **kwargs)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        with self._lock:
            self._callback_manager.on_llm_new_token(token, **kwargs)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        with self._lock:
            self._callback_manager.on_llm_error(error, **kwargs)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        with self._lock:
            self._callback_manager.on_chain_start(serialized, inputs, **kwargs)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        with self._lock:
            self._callback_manager.on_chain_end(outputs, **kwargs)

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        with self._lock:
            self._callback_manager.on_chain_error(error, **kwargs)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        with self._lock:
            self._callback_manager.on_tool_start(serialized, input_str, **kwargs)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        with self._lock:
            self._callback_manager.on_agent_action(action, **kwargs)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        with self._lock:
            self._callback_manager.on_tool_end(output, **kwargs)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        with self._lock:
            self._callback_manager.on_tool_error(error, **kwargs)

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""
        with self._lock:
            self._callback_manager.on_text(text, **kwargs)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run on agent end."""
        with self._lock:
            self._callback_manager.on_agent_finish(finish, **kwargs)

    def add_handler(self, callback: BaseCallbackHandler) -> None:
        """Add a callback to the callback manager."""
        with self._lock:
            self._callback_manager.add_handler(callback)

    def remove_handler(self, callback: BaseCallbackHandler) -> None:
        """Remove a callback from the callback manager."""
        with self._lock:
            self._callback_manager.remove_handler(callback)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        with self._lock:
            self._callback_manager.handlers = handlers
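The deleted shared.py above relied on double-checked locking to make the singleton thread-safe. A minimal, self-contained sketch of that same pattern (the DemoSingleton name is illustrative, not part of the library):

import threading
from typing import Any


class DemoSingleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls) -> Any:
        if cls._instance is None:
            with cls._lock:
                # Re-check inside the lock: another thread may have
                # created the instance while we waited to acquire it.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance


assert DemoSingleton() is DemoSingleton()  # every call yields the same object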
@@ -1,12 +1,5 @@
"""Tracers that record execution of LangChain runs."""

from langchain.callbacks.tracers.base import SharedTracer, Tracer
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer


class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
    """Shared tracer that records LangChain execution to LangChain endpoint."""


class LangChainTracer(Tracer, BaseLangChainTracer):
    """Tracer that records LangChain execution to LangChain endpoint."""
__all__ = ["LangChainTracer"]
@@ -1,14 +1,12 @@
"""Base interfaces for tracing runs."""
from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.shared import Singleton
from langchain.callbacks.tracers.schemas import (
    ChainRun,
    LLMRun,
@@ -16,7 +14,7 @@ from langchain.callbacks.tracers.schemas import (
    TracerSession,
    TracerSessionCreate,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult


class TracerException(Exception):
@@ -26,13 +24,25 @@ class TracerException(Exception):
class BaseTracer(BaseCallbackHandler, ABC):
    """Base interface for tracers."""

    @abstractmethod
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
        self.session: Optional[TracerSession] = None

    @staticmethod
    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        if isinstance(child_run, LLMRun):
            parent_run.child_llm_runs.append(child_run)
        elif isinstance(child_run, ChainRun):
            parent_run.child_chain_runs.append(child_run)
        elif isinstance(child_run, ToolRun):
            parent_run.child_tool_runs.append(child_run)
        else:
            raise TracerException(f"Invalid run type: {type(child_run)}")

    @abstractmethod
    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
@@ -42,15 +52,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session."""

    @abstractmethod
    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run."""

    def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
        """NOT thread safe, do not call this method from multiple threads."""
        session_create = TracerSessionCreate(name=name, extra=kwargs)
        session = self._persist_session(session_create)
        self._session = session
        self.session = session
        return session

    @abstractmethod
@@ -61,283 +67,232 @@ class BaseTracer(BaseCallbackHandler, ABC):
    def load_default_session(self) -> TracerSession:
        """Load the default tracing session and set it as the Tracer's session."""

    @property
    @abstractmethod
    def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
        """Get the tracer stack."""

    @property
    @abstractmethod
    def _execution_order(self) -> int:
        """Get the execution order for a run."""

    @_execution_order.setter
    @abstractmethod
    def _execution_order(self, value: int) -> None:
        """Set the execution order for a run."""

    @property
    @abstractmethod
    def _session(self) -> Optional[TracerSession]:
        """Get the tracing session."""

    @_session.setter
    @abstractmethod
    def _session(self, value: TracerSession) -> None:
        """Set the tracing session."""

    def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Start a trace for a run."""
        self._execution_order += 1

        if self._stack:
            if not (
                isinstance(self._stack[-1], ChainRun)
                or isinstance(self._stack[-1], ToolRun)
            ):
        if run.parent_uuid:
            parent_run = self.run_map[run.parent_uuid]
            if parent_run:
                if isinstance(parent_run, LLMRun):
                    raise TracerException(
                        "Cannot add child run to an LLM run. "
                        "LLM runs are not allowed to have children."
                    )
                self._add_child_run(parent_run, run)
            else:
                raise TracerException(
                    f"Nested {run.__class__.__name__} can only be"
                    f" logged inside a ChainRun or ToolRun"
                    f"Parent run with UUID {run.parent_uuid} not found."
                )
            self._add_child_run(self._stack[-1], run)
        self._stack.append(run)

    def _end_trace(self) -> None:
        self.run_map[run.uuid] = run

    def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """End a trace for a run."""
        run = self._stack.pop()
        if not self._stack:
            self._execution_order = 1
        if not run.parent_uuid:
            self._persist_run(run)
        else:
            parent_run = self.run_map.get(run.parent_uuid)
            if parent_run is None:
                raise TracerException(
                    f"Parent run with UUID {run.parent_uuid} not found."
                )
            if isinstance(parent_run, LLMRun):
                raise TracerException("LLM Runs are not allowed to have children. ")
            if run.child_execution_order > parent_run.child_execution_order:
                parent_run.child_execution_order = run.child_execution_order
        self.run_map.pop(run.uuid)

    def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
        """Get the execution order for a run."""
        if parent_run_id is None:
            return 1

        parent_run = self.run_map.get(parent_run_id)
        if parent_run is None:
            raise TracerException(f"Parent run with UUID {parent_run_id} not found.")

        if isinstance(parent_run, LLMRun):
            raise TracerException("LLM Runs are not allowed to have children. ")

        return parent_run.child_execution_order + 1

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Start a trace for an LLM run."""
        if self._session is None:
            raise TracerException(
                "Initialize a session with `new_session()` before starting a trace."
            )
        if self.session is None:
            self.session = self.load_default_session()

        if run_id is None:
            run_id = str(uuid4())

        execution_order = self._get_execution_order(parent_run_id)
        llm_run = LLMRun(
            uuid=run_id,
            parent_uuid=parent_run_id,
            serialized=serialized,
            prompts=prompts,
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=self._execution_order,
            session_id=self._session.id,
            id=self._generate_id(),
            execution_order=execution_order,
            child_execution_order=execution_order,
            session_id=self.session.id,
        )
        self._start_trace(llm_run)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Handle a new token for an LLM run."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None:
        """End a trace for an LLM run."""
        if not self._stack or not isinstance(self._stack[-1], LLMRun):
        if not run_id:
            raise TracerException("No run_id provided for on_llm_end callback.")

        llm_run = self.run_map.get(run_id)
        if llm_run is None or not isinstance(llm_run, LLMRun):
            raise TracerException("No LLMRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].response = response

        self._end_trace()
        llm_run.response = response
        llm_run.end_time = datetime.utcnow()
        self._end_trace(llm_run)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        **kwargs: Any,
    ) -> None:
        """Handle an error for an LLM run."""
        if not self._stack or not isinstance(self._stack[-1], LLMRun):
        if not run_id:
            raise TracerException("No run_id provided for on_llm_error callback.")

        llm_run = self.run_map.get(run_id)
        if llm_run is None or not isinstance(llm_run, LLMRun):
            raise TracerException("No LLMRun found to be traced")

        self._stack[-1].error = repr(error)
        self._stack[-1].end_time = datetime.utcnow()

        self._end_trace()
        llm_run.error = repr(error)
        llm_run.end_time = datetime.utcnow()
        self._end_trace(llm_run)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Start a trace for a chain run."""
        if self._session is None:
            raise TracerException(
                "Initialize a session with `new_session()` before starting a trace."
            )
        if self.session is None:
            self.session = self.load_default_session()

        execution_order = self._get_execution_order(parent_run_id)
        chain_run = ChainRun(
            uuid=run_id,
            parent_uuid=parent_run_id,
            serialized=serialized,
            inputs=inputs,
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=self._execution_order,
            execution_order=execution_order,
            child_execution_order=execution_order,
            child_runs=[],
            session_id=self._session.id,
            id=self._generate_id(),
            session_id=self.session.id,
        )
        self._start_trace(chain_run)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    def on_chain_end(
        self, outputs: Dict[str, Any], *, run_id: str, **kwargs: Any
    ) -> None:
        """End a trace for a chain run."""
        if not self._stack or not isinstance(self._stack[-1], ChainRun):
        chain_run = self.run_map.get(run_id)
        if chain_run is None or not isinstance(chain_run, ChainRun):
            raise TracerException("No ChainRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].outputs = outputs

        self._end_trace()
        chain_run.outputs = outputs
        chain_run.end_time = datetime.utcnow()
        self._end_trace(chain_run)

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        **kwargs: Any,
    ) -> None:
        """Handle an error for a chain run."""
        if not self._stack or not isinstance(self._stack[-1], ChainRun):
        chain_run = self.run_map.get(run_id)
        if chain_run is None or not isinstance(chain_run, ChainRun):
            raise TracerException("No ChainRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].error = repr(error)

        self._end_trace()
        chain_run.error = repr(error)
        chain_run.end_time = datetime.utcnow()
        self._end_trace(chain_run)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Start a trace for a tool run."""
        if self._session is None:
            raise TracerException(
                "Initialize a session with `new_session()` before starting a trace."
            )
        if self.session is None:
            self.session = self.load_default_session()

        execution_order = self._get_execution_order(parent_run_id)
        tool_run = ToolRun(
            uuid=run_id,
            parent_uuid=parent_run_id,
            serialized=serialized,
            # TODO: this is duplicate info as above, not needed.
            action=str(serialized),
            tool_input=input_str,
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=self._execution_order,
            execution_order=execution_order,
            child_execution_order=execution_order,
            child_runs=[],
            session_id=self._session.id,
            id=self._generate_id(),
            session_id=self.session.id,
        )
        self._start_trace(tool_run)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
    def on_tool_end(self, output: str, *, run_id: str, **kwargs: Any) -> None:
        """End a trace for a tool run."""
        if not self._stack or not isinstance(self._stack[-1], ToolRun):
        tool_run = self.run_map.get(run_id)
        if tool_run is None or not isinstance(tool_run, ToolRun):
            raise TracerException("No ToolRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].output = output

        self._end_trace()
        tool_run.output = output
        tool_run.end_time = datetime.utcnow()
        self._end_trace(tool_run)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: str,
        **kwargs: Any,
    ) -> None:
        """Handle an error for a tool run."""
        if not self._stack or not isinstance(self._stack[-1], ToolRun):
        tool_run = self.run_map.get(run_id)
        if tool_run is None or not isinstance(tool_run, ToolRun):
            raise TracerException("No ToolRun found to be traced")

        self._stack[-1].end_time = datetime.utcnow()
        self._stack[-1].error = repr(error)
        tool_run.error = repr(error)
        tool_run.end_time = datetime.utcnow()
        self._end_trace(tool_run)

        self._end_trace()
    def __deepcopy__(self, memo: dict) -> BaseTracer:
        """Deepcopy the tracer."""
        return self

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Handle a text message."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Handle an agent finish message."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing."""
        pass


class Tracer(BaseTracer, ABC):
    """A non-thread safe implementation of the BaseTracer interface."""

    def __init__(self) -> None:
        """Initialize a tracer."""
        self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
        self._tracer_execution_order = 1
        self._tracer_session: Optional[TracerSession] = None

    @property
    def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
        """Get the tracer stack."""
        return self._tracer_stack

    @property
    def _execution_order(self) -> int:
        """Get the execution order for a run."""
        return self._tracer_execution_order

    @_execution_order.setter
    def _execution_order(self, value: int) -> None:
        """Set the execution order for a run."""
        self._tracer_execution_order = value

    @property
    def _session(self) -> Optional[TracerSession]:
        """Get the tracing session."""
        return self._tracer_session

    @_session.setter
    def _session(self, value: TracerSession) -> None:
        """Set the tracing session."""
        if self._stack:
            raise TracerException(
                "Cannot set a session while a trace is being recorded"
            )
        self._tracer_session = value


@dataclass
class TracerStack(threading.local):
    """A stack of runs used for logging."""

    stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
    execution_order: int = 1


class SharedTracer(Singleton, BaseTracer, ABC):
    """A thread-safe Singleton implementation of BaseTracer."""

    _tracer_stack = TracerStack()
    _tracer_session = None

    @property
    def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
        """Get the tracer stack."""
        return self._tracer_stack.stack

    @property
    def _execution_order(self) -> int:
        """Get the execution order for a run."""
        return self._tracer_stack.execution_order

    @_execution_order.setter
    def _execution_order(self, value: int) -> None:
        """Set the execution order for a run."""
        self._tracer_stack.execution_order = value

    @property
    def _session(self) -> Optional[TracerSession]:
        """Get the tracing session."""
        return self._tracer_session

    @_session.setter
    def _session(self, value: TracerSession) -> None:
        """Set the tracing session."""
        with self._lock:
            # TODO: currently, we are only checking current thread's stack.
            # Need to make sure that we are not in the middle of a trace
            # in any thread.
            if self._stack:
                raise TracerException(
                    "Cannot set a session while a trace is being recorded"
                )
            self._tracer_session = value
    def __copy__(self) -> BaseTracer:
        """Copy the tracer."""
        return self
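The refactored BaseTracer above drops the per-tracer execution stack in favor of a run_map keyed by run_id, so runs started from different threads or coroutines can no longer interleave. A minimal sketch of that bookkeeping, under illustrative names (Run, MiniTracer, and the start/end methods are stand-ins, not the real API):

from dataclasses import dataclass, field
from typing import Dict, List, Optional
from uuid import uuid4


@dataclass
class Run:
    uuid: str
    parent_uuid: Optional[str] = None
    children: List["Run"] = field(default_factory=list)


class MiniTracer:
    def __init__(self) -> None:
        self.run_map: Dict[str, Run] = {}
        self.persisted: List[Run] = []

    def start(self, parent_uuid: Optional[str] = None) -> str:
        run = Run(uuid=str(uuid4()), parent_uuid=parent_uuid)
        if parent_uuid is not None:
            # Link the child to its parent by id, not by stack position.
            self.run_map[parent_uuid].children.append(run)
        self.run_map[run.uuid] = run
        return run.uuid

    def end(self, run_id: str) -> None:
        run = self.run_map.pop(run_id)
        if run.parent_uuid is None:
            # Only top-level runs are persisted; children travel with them.
            self.persisted.append(run)


tracer = MiniTracer()
root = tracer.start()
child = tracer.start(parent_uuid=root)
tracer.end(child)
tracer.end(root)
assert len(tracer.persisted) == 1 and len(tracer.persisted[0].children) == 1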
@@ -3,7 +3,6 @@ from __future__ import annotations

import logging
import os
from abc import ABC
from typing import Any, Dict, Optional, Union

import requests
@@ -18,14 +17,17 @@ from langchain.callbacks.tracers.schemas import (
)


class BaseLangChainTracer(BaseTracer, ABC):
class LangChainTracer(BaseTracer):
    """An implementation of the SharedTracer that POSTS to the langchain endpoint."""

    always_verbose: bool = True
    _endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
    _headers: Dict[str, Any] = {"Content-Type": "application/json"}
    if os.getenv("LANGCHAIN_API_KEY"):
        _headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
    def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
        """Initialize the LangChain tracer."""
        super().__init__(**kwargs)
        self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
        self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
        if os.getenv("LANGCHAIN_API_KEY"):
            self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
        self.session = self.load_session(session_name)

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run."""
@@ -59,54 +61,29 @@ class BaseLangChainTracer(BaseTracer, ABC):
        session = TracerSession(id=1, **session_create.dict())
        return session

    def load_session(self, session_name: str) -> TracerSession:
    def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
        """Load a session from the tracer."""
        try:
            r = requests.get(
                f"{self._endpoint}/sessions?name={session_name}",
                headers=self._headers,
            )
            url = f"{self._endpoint}/sessions"
            if session_name:
                url += f"?name={session_name}"
            r = requests.get(url, headers=self._headers)

            tracer_session = TracerSession(**r.json()[0])
            self._session = tracer_session
            return tracer_session
        except Exception as e:
            session_type = "default" if not session_name else session_name
            logging.warning(
                f"Failed to load session {session_name}, using empty session: {e}"
                f"Failed to load {session_type} session, using empty session: {e}"
            )
            tracer_session = TracerSession(id=1)
            self._session = tracer_session
            return tracer_session

        self.session = tracer_session
        return tracer_session

    def load_session(self, session_name: str) -> TracerSession:
        """Load a session with the given name from the tracer."""
        return self._load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Load the default tracing session and set it as the Tracer's session."""
        try:
            r = requests.get(
                f"{self._endpoint}/sessions",
                headers=self._headers,
            )
            # Use the first session result
            tracer_session = TracerSession(**r.json()[0])
            self._session = tracer_session
            return tracer_session
        except Exception as e:
            logging.warning(f"Failed to default session, using empty session: {e}")
            tracer_session = TracerSession(id=1)
            self._session = tracer_session
            return tracer_session

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        if isinstance(child_run, LLMRun):
            parent_run.child_llm_runs.append(child_run)
        elif isinstance(child_run, ChainRun):
            parent_run.child_chain_runs.append(child_run)
        else:
            parent_run.child_tool_runs.append(child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run."""
        return None
        return self._load_session("default")
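Per the __init__ above, LangChainTracer reads its endpoint and optional API key from environment variables and eagerly loads a session; if the endpoint is unreachable, _load_session logs a warning and falls back to an empty TracerSession(id=1). A hedged usage sketch ("my_session" and the localhost endpoint are illustrative values, not requirements):

import os

# Both variables are read at construction time, per the code above.
os.environ["LANGCHAIN_ENDPOINT"] = "http://localhost:8000"
os.environ["LANGCHAIN_API_KEY"] = "..."  # optional; only sent as x-api-key if set

from langchain.callbacks.tracers.langchain import LangChainTracer

tracer = LangChainTracer(session_name="my_session")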
@@ -32,11 +32,13 @@ class TracerSession(TracerSessionBase):
class BaseRun(BaseModel):
    """Base class for Run."""

    id: Optional[Union[int, str]] = None
    uuid: str
    parent_uuid: Optional[str] = None
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    extra: Optional[Dict[str, Any]] = None
    execution_order: int
    child_execution_order: int
    serialized: Dict[str, Any]
    session_id: int
    error: Optional[str] = None
@@ -57,7 +59,6 @@ class ChainRun(BaseRun):
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)
    child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)


class ToolRun(BaseRun):
@@ -69,7 +70,6 @@ class ToolRun(BaseRun):
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)
    child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)


ChainRun.update_forward_refs()
@@ -5,12 +5,16 @@ from typing import Any, Dict, List, Optional

from pydantic import Field, root_validator

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel


class APIChain(Chain):
@@ -61,16 +65,21 @@ class APIChain(Chain):
        )
        return values

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        api_url = self.api_request_chain.predict(
            question=question, api_docs=self.api_docs
        )
        self.callback_manager.on_text(
            api_url, color="green", end="\n", verbose=self.verbose
            question=question,
            api_docs=self.api_docs,
            callbacks=_run_manager.get_child(),
        )
        _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
        api_response = self.requests_wrapper.get(api_url)
        self.callback_manager.on_text(
        _run_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = self.api_answer_chain.predict(
@@ -78,19 +87,27 @@ class APIChain(Chain):
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}

    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        api_url = await self.api_request_chain.apredict(
            question=question, api_docs=self.api_docs
            question=question,
            api_docs=self.api_docs,
            callbacks=_run_manager.get_child(),
        )
        self.callback_manager.on_text(
        await _run_manager.on_text(
            api_url, color="green", end="\n", verbose=self.verbose
        )
        api_response = await self.requests_wrapper.aget(api_url)
        self.callback_manager.on_text(
        await _run_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = await self.api_answer_chain.apredict(
@@ -98,6 +115,7 @@ class APIChain(Chain):
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
from pydantic import BaseModel, Field
from requests import Response

from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
@@ -106,16 +107,21 @@ class OpenAPIEndpointChain(Chain, BaseModel):
        else:
            return {self.output_key: output}

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        intermediate_steps = {}
        instructions = inputs[self.instructions_key]
        instructions = instructions[: self.max_text_length]
        _api_arguments = self.api_request_chain.predict_and_parse(
            instructions=instructions
            instructions=instructions, callbacks=_run_manager.get_child()
        )
        api_arguments = cast(str, _api_arguments)
        intermediate_steps["request_args"] = api_arguments
        self.callback_manager.on_text(
        _run_manager.on_text(
            api_arguments, color="green", end="\n", verbose=self.verbose
        )
        if api_arguments.startswith("ERROR"):
@@ -141,18 +147,17 @@ class OpenAPIEndpointChain(Chain, BaseModel):
            response_text = f"Error with message {str(e)}"
        response_text = response_text[: self.max_text_length]
        intermediate_steps["response_text"] = response_text
        self.callback_manager.on_text(
        _run_manager.on_text(
            response_text, color="blue", end="\n", verbose=self.verbose
        )
        if self.api_response_chain is not None:
            _answer = self.api_response_chain.predict_and_parse(
                response=response_text,
                instructions=instructions,
                callbacks=_run_manager.get_child(),
            )
            answer = cast(str, _answer)
            self.callback_manager.on_text(
                answer, color="yellow", end="\n", verbose=self.verbose
            )
            _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
            return self._get_output(answer, intermediate_steps)
        else:
            return self._get_output(response_text, intermediate_steps)
@@ -188,6 +193,7 @@ class OpenAPIEndpointChain(Chain, BaseModel):
        verbose: bool = False,
        return_intermediate_steps: bool = False,
        raw_response: bool = False,
        callbacks: Callbacks = None,
        **kwargs: Any
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
@@ -198,12 +204,17 @@ class OpenAPIEndpointChain(Chain, BaseModel):
            path_params=operation.path_params,
        )
        requests_chain = APIRequesterChain.from_llm_and_typescript(
            llm, typescript_definition=operation.to_typescript(), verbose=verbose
            llm,
            typescript_definition=operation.to_typescript(),
            verbose=verbose,
            callbacks=callbacks,
        )
        if raw_response:
            response_chain = None
        else:
            response_chain = APIResponderChain.from_llm(llm, verbose=verbose)
            response_chain = APIResponderChain.from_llm(
                llm, verbose=verbose, callbacks=callbacks
            )
        _requests = requests or Requests()
        return cls(
            api_request_chain=requests_chain,
@@ -213,5 +224,6 @@ class OpenAPIEndpointChain(Chain, BaseModel):
            param_mapping=param_mapping,
            verbose=verbose,
            return_intermediate_steps=return_intermediate_steps,
            callbacks=callbacks,
            **kwargs,
        )
@@ -2,6 +2,7 @@

import json
import re
from typing import Any

from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
from langchain.chains.llm import LLMChain
@@ -36,7 +37,11 @@ class APIRequesterChain(LLMChain):

    @classmethod
    def from_llm_and_typescript(
        cls, llm: BaseLLM, typescript_definition: str, verbose: bool = True
        cls,
        llm: BaseLLM,
        typescript_definition: str,
        verbose: bool = True,
        **kwargs: Any,
    ) -> LLMChain:
        """Get the request parser."""
        output_parser = APIRequesterOutputParser()
@@ -46,4 +51,4 @@ class APIRequesterChain(LLMChain):
            partial_variables={"schema": typescript_definition},
            input_variables=["instructions"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
        return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
@@ -2,6 +2,7 @@

import json
import re
from typing import Any

from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
from langchain.chains.llm import LLMChain
@@ -35,7 +36,7 @@ class APIResponderChain(LLMChain):
        """Get the response parser."""

    @classmethod
    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
    def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain:
        """Get the response parser."""
        output_parser = APIResponderOutputParser()
        prompt = PromptTemplate(
@@ -43,4 +44,4 @@ class APIResponderChain(LLMChain):
            output_parser=output_parser,
            input_variables=["response", "instructions"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
        return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
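The two classmethod constructors above now take **kwargs purely so constructor-level fields such as callbacks can be threaded through to the chain they build. A small sketch of that plumbing, with illustrative names (MiniChain is not a real class):

from typing import Any


class MiniChain:
    def __init__(self, llm: str, verbose: bool = False, callbacks: Any = None):
        self.llm, self.verbose, self.callbacks = llm, verbose, callbacks

    @classmethod
    def from_llm(cls, llm: str, verbose: bool = True, **kwargs: Any) -> "MiniChain":
        # Forward unknown keyword arguments straight to the constructor,
        # so new fields (like callbacks) need no signature change here.
        return cls(llm=llm, verbose=verbose, **kwargs)


chain = MiniChain.from_llm("fake-llm", callbacks=["my-handler"])
assert chain.callbacks == ["my-handler"]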
@@ -1,15 +1,23 @@
"""Base interface that all chains should implement."""
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator

import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForChainRun,
    CallbackManager,
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain.schema import BaseMemory


@@ -21,9 +29,8 @@ class Chain(BaseModel, ABC):
    """Base interface that all chains should implement."""

    memory: Optional[BaseMemory] = None
    callback_manager: BaseCallbackManager = Field(
        default_factory=get_callback_manager, exclude=True
    )
    callbacks: Callbacks = None
    callback_manager: Optional[BaseCallbackManager] = None
    verbose: bool = Field(
        default_factory=_get_verbosity
    )  # Whether to print the response text
@@ -37,15 +44,16 @@ class Chain(BaseModel, ABC):
    def _chain_type(self) -> str:
        raise NotImplementedError("Saving not supported for this chain type.")

    @validator("callback_manager", pre=True, always=True)
    def set_callback_manager(
        cls, callback_manager: Optional[BaseCallbackManager]
    ) -> BaseCallbackManager:
        """If callback manager is None, set it.

        This allows users to pass in None as callback manager, which is a nice UX.
        """
        return callback_manager or get_callback_manager()
    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    @validator("verbose", pre=True, always=True)
    def set_verbose(cls, verbose: Optional[bool]) -> bool:
@@ -82,15 +90,26 @@ class Chain(BaseModel, ABC):
        )

    @abstractmethod
    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""

    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""
        raise NotImplementedError("Async call not supported for this chain type.")

    def __call__(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

@@ -104,21 +123,31 @@ class Chain(BaseModel, ABC):

        """
        inputs = self.prep_inputs(inputs)
        self.callback_manager.on_chain_start(
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
        run_manager = callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
            verbose=self.verbose,
        )
        try:
            outputs = self._call(inputs)
            outputs = (
                self._call(inputs, run_manager=run_manager)
                if new_arg_supported
                else self._call(inputs)
            )
        except (KeyboardInterrupt, Exception) as e:
            self.callback_manager.on_chain_error(e, verbose=self.verbose)
            run_manager.on_chain_error(e)
            raise e
        self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
        run_manager.on_chain_end(outputs)
        return self.prep_outputs(inputs, outputs, return_only_outputs)

    async def acall(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

@@ -132,30 +161,24 @@ class Chain(BaseModel, ABC):

        """
        inputs = self.prep_inputs(inputs)
        if self.callback_manager.is_async:
            await self.callback_manager.on_chain_start(
                {"name": self.__class__.__name__},
                inputs,
                verbose=self.verbose,
            )
        else:
            self.callback_manager.on_chain_start(
                {"name": self.__class__.__name__},
                inputs,
                verbose=self.verbose,
            )
        callback_manager = AsyncCallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
        run_manager = await callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
        )
        try:
            outputs = await self._acall(inputs)
            outputs = (
                await self._acall(inputs, run_manager=run_manager)
                if new_arg_supported
                else await self._acall(inputs)
            )
        except (KeyboardInterrupt, Exception) as e:
            if self.callback_manager.is_async:
                await self.callback_manager.on_chain_error(e, verbose=self.verbose)
            else:
                self.callback_manager.on_chain_error(e, verbose=self.verbose)
            await run_manager.on_chain_error(e)
            raise e
        if self.callback_manager.is_async:
            await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
        else:
            self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
        await run_manager.on_chain_end(outputs)
        return self.prep_outputs(inputs, outputs, return_only_outputs)

    def prep_outputs(
@@ -195,11 +218,13 @@ class Chain(BaseModel, ABC):
        self._validate_inputs(inputs)
        return inputs

    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    def apply(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> List[Dict[str, str]]:
        """Call the chain on all inputs in the list."""
        return [self(inputs) for inputs in input_list]
        return [self(inputs, callbacks=callbacks) for inputs in input_list]

    def run(self, *args: Any, **kwargs: Any) -> str:
    def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
        """Run the chain as text in, text out or multiple variables, text out."""
        if len(self.output_keys) != 1:
            raise ValueError(
@@ -210,17 +235,17 @@ class Chain(BaseModel, ABC):
        if args and not kwargs:
            if len(args) != 1:
                raise ValueError("`run` supports only one positional argument.")
            return self(args[0])[self.output_keys[0]]
            return self(args[0], callbacks=callbacks)[self.output_keys[0]]

        if kwargs and not args:
            return self(kwargs)[self.output_keys[0]]
            return self(kwargs, callbacks=callbacks)[self.output_keys[0]]

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )

    async def arun(self, *args: Any, **kwargs: Any) -> str:
    async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
        """Run the chain as text in, text out or multiple variables, text out."""
        if len(self.output_keys) != 1:
            raise ValueError(
@@ -231,10 +256,10 @@ class Chain(BaseModel, ABC):
        if args and not kwargs:
            if len(args) != 1:
                raise ValueError("`run` supports only one positional argument.")
            return (await self.acall(args[0]))[self.output_keys[0]]
            return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]

        if kwargs and not args:
            return (await self.acall(kwargs))[self.output_keys[0]]
            return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
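Chain.__call__ above sniffs the subclass's _call signature with inspect so that chains written before the run_manager argument existed keep working unchanged. A minimal sketch of that dispatch (old_style, new_style, and dispatch are illustrative names):

import inspect


def old_style(inputs: dict) -> dict:
    return inputs


def new_style(inputs: dict, run_manager: object = None) -> dict:
    return inputs


def dispatch(fn, inputs: dict, run_manager: object) -> dict:
    # Parameter lookup returns None for functions without the new argument,
    # which is exactly the check Chain.__call__ performs above.
    new_arg_supported = inspect.signature(fn).parameters.get("run_manager")
    return fn(inputs, run_manager=run_manager) if new_arg_supported else fn(inputs)


assert dispatch(old_style, {"q": 1}, object()) == {"q": 1}
assert dispatch(new_style, {"q": 1}, object()) == {"q": 1}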
@@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional, Tuple

from pydantic import Field

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
@@ -68,19 +72,33 @@ class BaseCombineDocumentsChain(Chain, ABC):
    ) -> Tuple[str, dict]:
        """Combine documents into a single string asynchronously."""

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
    def _call(
        self,
        inputs: Dict[str, List[Document]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        output, extra_return_dict = self.combine_docs(docs, **other_keys)
        output, extra_return_dict = self.combine_docs(
            docs, callbacks=_run_manager.get_child(), **other_keys
        )
        extra_return_dict[self.output_key] = output
        return extra_return_dict

    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
    async def _acall(
        self,
        inputs: Dict[str, List[Document]],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        output, extra_return_dict = await self.acombine_docs(docs, **other_keys)
        output, extra_return_dict = await self.acombine_docs(
            docs, callbacks=_run_manager.get_child(), **other_keys
        )
        extra_return_dict[self.output_key] = output
        return extra_return_dict

@@ -108,10 +126,17 @@ class AnalyzeDocumentChain(Chain):
        """
        return self.combine_docs_chain.output_keys

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        document = inputs[self.input_key]
        docs = self.text_splitter.create_documents([document])
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
        other_keys[self.combine_docs_chain.input_key] = docs
        return self.combine_docs_chain(other_keys, return_only_outputs=True)
        return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
        )
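The _call/_acall overrides above all forward callbacks to sub-chains via _run_manager.get_child(), so nested runs report to the same handlers as the parent. A toy sketch of that propagation idea (RunManager here is a stand-in, not the langchain class):

from typing import Callable, List


class RunManager:
    def __init__(self, handlers: List[Callable[[str], None]]) -> None:
        self.handlers = handlers

    def on_text(self, text: str) -> None:
        for handler in self.handlers:
            handler(text)

    def get_child(self) -> List[Callable[[str], None]]:
        # A child run inherits the parent's handlers, so its events land
        # in the same place without any global registry.
        return list(self.handlers)


events: List[str] = []
parent = RunManager([events.append])
parent.on_text("parent step")
child = RunManager(parent.get_child())
child.on_text("child step")
assert events == ["parent step", "child step"]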
@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple

from pydantic import Extra, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
@@ -129,7 +130,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        return self.combine_document_chain

    def combine_docs(
        self, docs: List[Document], token_max: int = 3000, **kwargs: Any
        self,
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

@@ -138,12 +143,15 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        """
        results = self.llm_chain.apply(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(
            results, docs, token_max, callbacks=callbacks, **kwargs
        )
        return self._process_results(results, docs, token_max, **kwargs)

    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

@@ -152,15 +160,17 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        """
        results = await self.llm_chain.aapply(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(results, docs, **kwargs)
        return self._process_results(results, docs, callbacks=callbacks, **kwargs)

    def _process_results(
        self,
        results: List[Dict],
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        question_result_key = self.llm_chain.output_key
@@ -173,7 +183,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
        num_tokens = length_func(result_docs, **kwargs)

        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
            return self._collapse_chain.run(input_documents=docs, **kwargs)
            return self._collapse_chain.run(
                input_documents=docs, callbacks=callbacks, **kwargs
            )

        while num_tokens is not None and num_tokens > token_max:
            new_result_doc_list = _split_list_of_docs(
@@ -191,7 +203,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
            extra_return_dict = {"intermediate_steps": _results}
        else:
            extra_return_dict = {}
        output = self.combine_document_chain.run(input_documents=result_docs, **kwargs)
        output = self.combine_document_chain.run(
            input_documents=result_docs, callbacks=callbacks, **kwargs
        )
        return output, extra_return_dict

    @property
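The _process_results loop above repeatedly splits the mapped results into token-budgeted groups and collapses each group until everything fits under token_max. A schematic sketch of that loop (collapse, length, and combine are stand-ins for the real chain calls; combine must shrink its input or the loop would not terminate):

from typing import Callable, List


def collapse(
    docs: List[str],
    length: Callable[[List[str]], int],
    combine: Callable[[List[str]], str],
    token_max: int = 3000,
) -> List[str]:
    while length(docs) > token_max:
        # Greedily split docs into groups that each fit the budget, then
        # collapse every group into a single shorter document.
        groups: List[List[str]] = [[]]
        for doc in docs:
            if groups[-1] and length(groups[-1] + [doc]) > token_max:
                groups.append([])
            groups[-1].append(doc)
        docs = [combine(group) for group in groups]
    return docs


docs = ["a" * 100] * 100
result = collapse(
    docs,
    length=lambda ds: sum(len(d) for d in ds),
    combine=lambda ds: ds[0][:50],  # toy "summarizer" that always shrinks
    token_max=3000,
)
assert sum(len(d) for d in result) <= 3000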
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

from pydantic import Extra, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
@@ -89,19 +90,22 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
        )
        return values

    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner.

        Combine by mapping first chain over all documents, then reranking the results.
        """
        results = self.llm_chain.apply_and_parse(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(docs, results)

    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner.

@@ -109,7 +113,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
        """
        results = await self.llm_chain.aapply_and_parse(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(docs, results)
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Tuple

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
    format_document,
@@ -85,29 +86,31 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
        )
        return values

    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain."""
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = self.initial_llm_chain.predict(**inputs)
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(**inputs)
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        return self._construct_result(refine_steps, res)

    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain."""
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = await self.initial_llm_chain.apredict(**inputs)
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = await self.refine_llm_chain.apredict(**inputs)
            res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        return self._construct_result(refine_steps, res)
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
    format_document,
@@ -75,19 +76,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
        prompt = self.llm_chain.prompt.format(**inputs)
        return self.llm_chain.llm.get_num_tokens(prompt)

    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(**inputs), {}

    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return await self.llm_chain.apredict(**inputs), {}
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}

    @property
    def _chain_type(self) -> str:
langchain/chains/constitutional_ai/base.py
@@ -1,13 +1,14 @@
 """Chain for applying constitutional principles to the outputs of another chain."""
 from typing import Any, Dict, List, Optional

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.constitutional_ai.principles import PRINCIPLES
 from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class ConstitutionalChain(Chain):
@@ -86,11 +87,16 @@ class ConstitutionalChain(Chain):
         """Defines the output keys."""
         return ["output"]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         response = self.chain.run(**inputs)
         input_prompt = self.chain.prompt.format(**inputs)

-        self.callback_manager.on_text(
+        _run_manager.on_text(
             text="Initial response: " + response + "\n\n",
             verbose=self.verbose,
             color="yellow",
@@ -103,6 +109,7 @@ class ConstitutionalChain(Chain):
             input_prompt=input_prompt,
             output_from_model=response,
             critique_request=constitutional_principle.critique_request,
+            callbacks=_run_manager.get_child(),
         )
         critique = self._parse_critique(
             output_string=raw_critique,
@@ -116,22 +123,23 @@ class ConstitutionalChain(Chain):
             critique_request=constitutional_principle.critique_request,
             critique=critique,
             revision_request=constitutional_principle.revision_request,
+            callbacks=_run_manager.get_child(),
         ).strip()
         response = revision

-        self.callback_manager.on_text(
+        _run_manager.on_text(
             text=f"Applying {constitutional_principle.name}..." + "\n\n",
             verbose=self.verbose,
             color="green",
         )

-        self.callback_manager.on_text(
+        _run_manager.on_text(
             text="Critique: " + critique + "\n\n",
             verbose=self.verbose,
             color="blue",
         )

-        self.callback_manager.on_text(
+        _run_manager.on_text(
             text="Updated response: " + revision + "\n\n",
             verbose=self.verbose,
             color="yellow",
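The `run_manager or CallbackManagerForChainRun.get_noop_manager()` idiom above is the new contract for `_call`. A minimal custom chain following it might look like this (illustrative only; EchoChain is not part of the diff):

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    """Toy chain showing the new _call contract."""

    input_key: str = "text"
    output_key: str = "echo"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        # Fall back to a no-op manager so on_text is always safe to call;
        # any sub-chain work would go through _run_manager.get_child().
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
        return {self.output_key: inputs[self.input_key]}


print(EchoChain()({"text": "hello"}))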
langchain/chains/conversational_retrieval/base.py
@@ -8,6 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 from pydantic import Extra, Field, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@@ -15,7 +20,7 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
+from langchain.schema import BaseMessage, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore

 # Depending on the memory type and configuration, the chat history format may differ.
@@ -81,14 +86,20 @@ class BaseConversationalRetrievalChain(Chain):
     def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
         """Get docs."""

-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
         chat_history_str = get_chat_history(inputs["chat_history"])

         if chat_history_str:
+            callbacks = _run_manager.get_child()
             new_question = self.question_generator.run(
-                question=question, chat_history=chat_history_str
+                question=question, chat_history=chat_history_str, callbacks=callbacks
             )
         else:
             new_question = question
@@ -96,7 +107,9 @@ class BaseConversationalRetrievalChain(Chain):
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
-        answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs)
+        answer = self.combine_docs_chain.run(
+            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+        )
         if self.return_source_documents:
             return {self.output_key: answer, "source_documents": docs}
         else:
@@ -106,13 +119,19 @@ class BaseConversationalRetrievalChain(Chain):
     async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
         """Get docs."""

-    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
         chat_history_str = get_chat_history(inputs["chat_history"])
         if chat_history_str:
+            callbacks = _run_manager.get_child()
             new_question = await self.question_generator.arun(
-                question=question, chat_history=chat_history_str
+                question=question, chat_history=chat_history_str, callbacks=callbacks
             )
         else:
             new_question = question
@@ -120,7 +139,9 @@ class BaseConversationalRetrievalChain(Chain):
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
-        answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs)
+        answer = await self.combine_docs_chain.arun(
+            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+        )
         if self.return_source_documents:
             return {self.output_key: answer, "source_documents": docs}
         else:
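`_run_manager.get_child()` scopes sub-chain events under the parent run, so one handler passed at the top level observes the whole call tree. A rough sketch of the effect with two nested chains (FakeListLLM is a deterministic stand-in, not from this diff):

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

llm = FakeListLLM(responses=["condensed question", "final answer"])
condense = LLMChain(llm=llm, prompt=PromptTemplate.from_template("Condense: {input}"))
answer = LLMChain(llm=llm, prompt=PromptTemplate.from_template("Answer: {input}"))
pipeline = SimpleSequentialChain(chains=[condense, answer])

# One handler at the top; get_child() forwards it into each nested run.
pipeline.run("original question", callbacks=[StdOutCallbackHandler()])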
langchain/chains/graph_qa/base.py
@@ -1,10 +1,11 @@
 """Question answering over a graph."""
 from __future__ import annotations

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 from pydantic import Field

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
 from langchain.chains.llm import LLMChain
@@ -51,18 +52,25 @@ class GraphQAChain(Chain):
         qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
         entity_chain = LLMChain(llm=llm, prompt=entity_prompt)

-        return cls(qa_chain=qa_chain, entity_extraction_chain=entity_chain, **kwargs)
+        return cls(
+            qa_chain=qa_chain,
+            entity_extraction_chain=entity_chain,
+            **kwargs,
+        )

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
         """Extract entities, look up info and answer question."""
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]

         entity_string = self.entity_extraction_chain.run(question)

-        self.callback_manager.on_text(
-            "Entities Extracted:", end="\n", verbose=self.verbose
-        )
-        self.callback_manager.on_text(
+        _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
+        _run_manager.on_text(
             entity_string, color="green", end="\n", verbose=self.verbose
         )
         entities = get_entities(entity_string)
@@ -70,9 +78,10 @@ class GraphQAChain(Chain):
         for entity in entities:
             triplets = self.graph.get_entity_knowledge(entity)
             context += "\n".join(triplets)
-        self.callback_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
-        self.callback_manager.on_text(
-            context, color="green", end="\n", verbose=self.verbose
-        )
-        result = self.qa_chain({"question": question, "context": context})
+        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
+        _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
+        result = self.qa_chain(
+            {"question": question, "context": context},
+            callbacks=_run_manager.get_child(),
+        )
         return {self.output_key: result[self.qa_chain.output_key]}
langchain/chains/hyde/base.py
@@ -4,11 +4,12 @@ https://arxiv.org/abs/2212.10496
 """
 from __future__ import annotations

-from typing import Dict, List
+from typing import Any, Dict, List, Optional

 import numpy as np
 from pydantic import Extra

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
 from langchain.chains.llm import LLMChain
@@ -57,18 +58,27 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
         """Call the internal llm chain."""
-        return self.llm_chain._call(inputs)
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        return self.llm_chain(inputs, callbacks=_run_manager.get_child())

     @classmethod
     def from_llm(
-        cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str
+        cls,
+        llm: BaseLLM,
+        base_embeddings: Embeddings,
+        prompt_key: str,
+        **kwargs: Any,
     ) -> HypotheticalDocumentEmbedder:
         """Load and use LLMChain for a specific prompt key."""
         prompt = PROMPT_MAP[prompt_key]
         llm_chain = LLMChain(llm=llm, prompt=prompt)
-        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain)
+        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

     @property
     def _chain_type(self) -> str:
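With `**kwargs` now forwarded, `from_llm` can pass chain options such as `verbose`. A sketch (FakeListLLM and FakeEmbeddings are test stand-ins assumed available; "web_search" is one of the keys defined in PROMPT_MAP):

from langchain.chains import HypotheticalDocumentEmbedder
from langchain.embeddings.fake import FakeEmbeddings
from langchain.llms.fake import FakeListLLM

embedder = HypotheticalDocumentEmbedder.from_llm(
    FakeListLLM(responses=["A plausible hypothetical answer document."]),
    FakeEmbeddings(size=8),
    "web_search",  # prompt key from PROMPT_MAP
    verbose=True,  # now forwarded via **kwargs
)
print(len(embedder.embed_query("What is HyDE?")))  # -> 8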
langchain/chains/llm.py
@@ -5,11 +5,19 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 from pydantic import Extra

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManager,
+    AsyncCallbackManagerForChainRun,
+    CallbackManager,
+    CallbackManagerForChainRun,
+    Callbacks,
+)
 from langchain.chains.base import Chain
 from langchain.input import get_colored_text
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
+from langchain.schema import LLMResult, PromptValue


 class LLMChain(Chain):
@@ -53,21 +61,40 @@ class LLMChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
-        return self.apply([inputs])[0]
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        response = self.generate([inputs], run_manager=run_manager)
+        return self.create_outputs(response)[0]

-    def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
+    def generate(
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> LLMResult:
         """Generate LLM result from inputs."""
-        prompts, stop = self.prep_prompts(input_list)
-        return self.llm.generate_prompt(prompts, stop)
+        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
+        return self.llm.generate_prompt(
+            prompts, stop, callbacks=run_manager.get_child() if run_manager else None
+        )

-    async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
+    async def agenerate(
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> LLMResult:
         """Generate LLM result from inputs."""
-        prompts, stop = await self.aprep_prompts(input_list)
-        return await self.llm.agenerate_prompt(prompts, stop)
+        prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
+        return await self.llm.agenerate_prompt(
+            prompts, stop, callbacks=run_manager.get_child() if run_manager else None
+        )

     def prep_prompts(
-        self, input_list: List[Dict[str, Any]]
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Tuple[List[PromptValue], Optional[List[str]]]:
         """Prepare prompts from inputs."""
         stop = None
@@ -79,7 +106,8 @@ class LLMChain(Chain):
             prompt = self.prompt.format_prompt(**selected_inputs)
             _colored_text = get_colored_text(prompt.to_string(), "green")
             _text = "Prompt after formatting:\n" + _colored_text
-            self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
+            if run_manager:
+                run_manager.on_text(_text, end="\n", verbose=self.verbose)
             if "stop" in inputs and inputs["stop"] != stop:
                 raise ValueError(
                     "If `stop` is present in any inputs, should be present in all."
@@ -88,7 +116,9 @@ class LLMChain(Chain):
         return prompts, stop

     async def aprep_prompts(
-        self, input_list: List[Dict[str, Any]]
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Tuple[List[PromptValue], Optional[List[str]]]:
         """Prepare prompts from inputs."""
         stop = None
@@ -100,12 +130,8 @@ class LLMChain(Chain):
             prompt = self.prompt.format_prompt(**selected_inputs)
             _colored_text = get_colored_text(prompt.to_string(), "green")
             _text = "Prompt after formatting:\n" + _colored_text
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_text(
-                    _text, end="\n", verbose=self.verbose
-                )
-            else:
-                self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
+            if run_manager:
+                await run_manager.on_text(_text, end="\n", verbose=self.verbose)
             if "stop" in inputs and inputs["stop"] != stop:
                 raise ValueError(
                     "If `stop` is present in any inputs, should be present in all."
@@ -113,15 +139,45 @@ class LLMChain(Chain):
             prompts.append(prompt)
         return prompts, stop

-    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+    def apply(
+        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
+    ) -> List[Dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
-        response = self.generate(input_list)
-        return self.create_outputs(response)
+        callback_manager = CallbackManager.configure(
+            callbacks, self.callbacks, self.verbose
+        )
+        run_manager = callback_manager.on_chain_start(
+            {"name": self.__class__.__name__},
+            {"input_list": input_list},
+        )
+        try:
+            response = self.generate(input_list, run_manager=run_manager)
+        except (KeyboardInterrupt, Exception) as e:
+            run_manager.on_chain_error(e)
+            raise e
+        outputs = self.create_outputs(response)
+        run_manager.on_chain_end({"outputs": outputs})
+        return outputs

-    async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+    async def aapply(
+        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
+    ) -> List[Dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
-        response = await self.agenerate(input_list)
-        return self.create_outputs(response)
+        callback_manager = AsyncCallbackManager.configure(
+            callbacks, self.callbacks, self.verbose
+        )
+        run_manager = await callback_manager.on_chain_start(
+            {"name": self.__class__.__name__},
+            {"input_list": input_list},
+        )
+        try:
+            response = await self.agenerate(input_list, run_manager=run_manager)
+        except (KeyboardInterrupt, Exception) as e:
+            await run_manager.on_chain_error(e)
+            raise e
+        outputs = self.create_outputs(response)
+        await run_manager.on_chain_end({"outputs": outputs})
+        return outputs

     def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
         """Create outputs from response."""
@@ -131,13 +187,19 @@ class LLMChain(Chain):
             for generation in response.generations
         ]

-    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
-        return (await self.aapply([inputs]))[0]
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        response = await self.agenerate([inputs], run_manager=run_manager)
+        return self.create_outputs(response)[0]

-    def predict(self, **kwargs: Any) -> str:
+    def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
         """Format prompt with kwargs and pass to LLM.

         Args:
+            callbacks: Callbacks to pass to LLMChain
             **kwargs: Keys to pass to prompt template.

         Returns:
@@ -148,12 +210,13 @@ class LLMChain(Chain):

                 completion = llm.predict(adjective="funny")
         """
-        return self(kwargs)[self.output_key]
+        return self(kwargs, callbacks=callbacks)[self.output_key]

-    async def apredict(self, **kwargs: Any) -> str:
+    async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
         """Format prompt with kwargs and pass to LLM.

         Args:
+            callbacks: Callbacks to pass to LLMChain
             **kwargs: Keys to pass to prompt template.

         Returns:
@@ -164,31 +227,33 @@ class LLMChain(Chain):

                 completion = llm.predict(adjective="funny")
         """
-        return (await self.acall(kwargs))[self.output_key]
+        return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]

-    def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
+    def predict_and_parse(
+        self, callbacks: Callbacks = None, **kwargs: Any
+    ) -> Union[str, List[str], Dict[str, str]]:
         """Call predict and then parse the results."""
-        result = self.predict(**kwargs)
+        result = self.predict(callbacks=callbacks, **kwargs)
         if self.prompt.output_parser is not None:
             return self.prompt.output_parser.parse(result)
         else:
             return result

     async def apredict_and_parse(
-        self, **kwargs: Any
+        self, callbacks: Callbacks = None, **kwargs: Any
     ) -> Union[str, List[str], Dict[str, str]]:
         """Call apredict and then parse the results."""
-        result = await self.apredict(**kwargs)
+        result = await self.apredict(callbacks=callbacks, **kwargs)
         if self.prompt.output_parser is not None:
             return self.prompt.output_parser.parse(result)
         else:
             return result

     def apply_and_parse(
-        self, input_list: List[Dict[str, Any]]
+        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
     ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         """Call apply and then parse the results."""
-        result = self.apply(input_list)
+        result = self.apply(input_list, callbacks=callbacks)
         return self._parse_result(result)

     def _parse_result(
@@ -202,10 +267,10 @@ class LLMChain(Chain):
         return result

     async def aapply_and_parse(
-        self, input_list: List[Dict[str, Any]]
+        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         """Call apply and then parse the results."""
-        result = await self.aapply(input_list)
+        result = await self.aapply(input_list, callbacks=callbacks)
         return self._parse_result(result)

     @property
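`apply` now opens its own chain run and reports start/end/error to the configured handlers, while `predict` simply forwards `callbacks` to `__call__`. A usage sketch (FakeListLLM is a deterministic stand-in assumed available here):

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

chain = LLMChain(
    llm=FakeListLLM(responses=["Paris", "Rome"]),
    prompt=PromptTemplate.from_template("Capital of {country}?"),
)
handler = StdOutCallbackHandler()

# Per-call callbacks on a single prediction.
print(chain.predict(country="France", callbacks=[handler]))

# apply() wraps the whole batch in one chain run visible to the handler.
print(chain.apply([{"country": "Italy"}], callbacks=[handler]))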
langchain/chains/llm_bash/base.py
@@ -1,47 +1,24 @@
 """Chain that interprets a prompt and executes bash code to perform bash operations."""
 from __future__ import annotations

 import logging
-import re
-from typing import Any, Dict, List
+import warnings
+from typing import Any, Dict, List, Optional

-from pydantic import Extra, Field
+from pydantic import Extra, Field, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_bash.prompt import PROMPT
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
+from langchain.schema import OutputParserException
 from langchain.utilities.bash import BashProcess

 logger = logging.getLogger(__name__)


-class BashOutputParser(BaseOutputParser):
-    """Parser for bash output."""
-
-    def parse(self, text: str) -> List[str]:
-        if "```bash" in text:
-            return self.get_code_blocks(text)
-        else:
-            raise OutputParserException(
-                f"Failed to parse bash output. Got: {text}",
-            )
-
-    @staticmethod
-    def get_code_blocks(t: str) -> List[str]:
-        """Get multiple code blocks from the LLM result."""
-        code_blocks: List[str] = []
-        # Bash markdown code blocks
-        pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
-        for match in pattern.finditer(t):
-            matched = match.group(1).strip()
-            if matched:
-                code_blocks.extend(
-                    [line for line in matched.split("\n") if line.strip()]
-                )
-
-        return code_blocks
-
-
 class LLMBashChain(Chain):
     """Chain that interprets a prompt and executes bash code to perform bash operations.
@@ -49,15 +26,16 @@ class LLMBashChain(Chain):
     .. code-block:: python

             from langchain import LLMBashChain, OpenAI
-            llm_bash = LLMBashChain(llm=OpenAI())
+            llm_bash = LLMBashChain.from_llm(OpenAI())
     """

-    llm: BaseLanguageModel
-    """LLM wrapper to use."""
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated] LLM wrapper to use."""
     input_key: str = "question"  #: :meta private:
     output_key: str = "answer"  #: :meta private:
     prompt: BasePromptTemplate = PROMPT
-    output_parser: BaseOutputParser = Field(default_factory=BashOutputParser)
+    """[Deprecated]"""
     bash_process: BashProcess = Field(default_factory=BashProcess)  #: :meta private:

     class Config:
@@ -66,6 +44,26 @@ class LLMBashChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMBashChain with an llm is deprecated. "
+                "Please instantiate with llm_chain or using the from_llm class method."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                prompt = values.get("prompt", PROMPT)
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+        return values
+
+    @root_validator
+    def validate_prompt(cls, values: Dict) -> Dict:
+        if values["llm_chain"].prompt.output_parser is None:
+            raise ValueError(
+                "The prompt used by llm_chain is expected to have an output_parser."
+            )
+        return values
+
     @property
     def input_keys(self) -> List[str]:
         """Expect input key.
@@ -82,30 +80,33 @@ class LLMBashChain(Chain):
     """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
-        llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
-
-        self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
-
-        t = llm_executor.predict(question=inputs[self.input_key])
-        self.callback_manager.on_text(t, color="green", verbose=self.verbose)
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        _run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
+
+        t = self.llm_chain.predict(
+            question=inputs[self.input_key], callbacks=_run_manager.get_child()
+        )
+        _run_manager.on_text(t, color="green", verbose=self.verbose)
         t = t.strip()
         try:
-            command_list = self.output_parser.parse(t)
+            command_list = self.llm_chain.prompt.output_parser.parse(t)  # type: ignore[union-attr]
         except OutputParserException as e:
-            self.callback_manager.on_chain_error(e, verbose=self.verbose)
+            _run_manager.on_chain_error(e, verbose=self.verbose)
             raise e

         if self.verbose:
-            self.callback_manager.on_text("\nCode: ", verbose=self.verbose)
-            self.callback_manager.on_text(
+            _run_manager.on_text("\nCode: ", verbose=self.verbose)
+            _run_manager.on_text(
                 str(command_list), color="yellow", verbose=self.verbose
             )

         output = self.bash_process.run(command_list)

-        self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
-        self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
+        _run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+        _run_manager.on_text(output, color="yellow", verbose=self.verbose)
         return {self.output_key: output}

     @property
@@ -113,11 +114,11 @@ class LLMBashChain(Chain):
         return "llm_bash_chain"

     @classmethod
-    def from_bash_process(
+    def from_llm(
         cls,
-        bash_process: BashProcess,
         llm: BaseLanguageModel,
+        prompt: BasePromptTemplate = PROMPT,
         **kwargs: Any,
-    ) -> "LLMBashChain":
-        """Create a LLMBashChain from a BashProcess."""
-        return cls(llm=llm, bash_process=bash_process, **kwargs)
+    ) -> LLMBashChain:
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
langchain/chains/llm_bash/prompt.py
@@ -1,5 +1,11 @@
 # flake8: noqa
+from __future__ import annotations
+
+import re
+from typing import List
+
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BaseOutputParser, OutputParserException

 _PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:

@@ -19,4 +25,36 @@ That is the format. Begin!

 Question: {question}"""

-PROMPT = PromptTemplate(input_variables=["question"], template=_PROMPT_TEMPLATE)
+
+class BashOutputParser(BaseOutputParser):
+    """Parser for bash output."""
+
+    def parse(self, text: str) -> List[str]:
+        if "```bash" in text:
+            return self.get_code_blocks(text)
+        else:
+            raise OutputParserException(
+                f"Failed to parse bash output. Got: {text}",
+            )
+
+    @staticmethod
+    def get_code_blocks(t: str) -> List[str]:
+        """Get multiple code blocks from the LLM result."""
+        code_blocks: List[str] = []
+        # Bash markdown code blocks
+        pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
+        for match in pattern.finditer(t):
+            matched = match.group(1).strip()
+            if matched:
+                code_blocks.extend(
+                    [line for line in matched.split("\n") if line.strip()]
+                )
+
+        return code_blocks
+
+
+PROMPT = PromptTemplate(
+    input_variables=["question"],
+    template=_PROMPT_TEMPLATE,
+    output_parser=BashOutputParser(),
+)
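A quick check of the relocated parser (illustrative; the parser extracts individual command lines from a bash code fence):

from langchain.chains.llm_bash.prompt import BashOutputParser

parser = BashOutputParser()
text = "I need to list the files.\n\n```bash\nls\necho 'done'\n```"
print(parser.parse(text))  # -> ['ls', "echo 'done'"]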
langchain/chains/llm_checker/base.py
@@ -1,10 +1,12 @@
 """Chain for question-answering with self-verification."""
 from __future__ import annotations

-from typing import Dict, List
+import warnings
+from typing import Any, Dict, List, Optional

-from pydantic import Extra
+from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_checker.prompt import (
@@ -18,6 +20,48 @@ from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate


+def _load_question_to_checked_assertions_chain(
+    llm: BaseLLM,
+    create_draft_answer_prompt: PromptTemplate,
+    list_assertions_prompt: PromptTemplate,
+    check_assertions_prompt: PromptTemplate,
+    revised_answer_prompt: PromptTemplate,
+) -> SequentialChain:
+    create_draft_answer_chain = LLMChain(
+        llm=llm,
+        prompt=create_draft_answer_prompt,
+        output_key="statement",
+    )
+    list_assertions_chain = LLMChain(
+        llm=llm,
+        prompt=list_assertions_prompt,
+        output_key="assertions",
+    )
+    check_assertions_chain = LLMChain(
+        llm=llm,
+        prompt=check_assertions_prompt,
+        output_key="checked_assertions",
+    )
+    revised_answer_chain = LLMChain(
+        llm=llm,
+        prompt=revised_answer_prompt,
+        output_key="revised_statement",
+    )
+    chains = [
+        create_draft_answer_chain,
+        list_assertions_chain,
+        check_assertions_chain,
+        revised_answer_chain,
+    ]
+    question_to_checked_assertions_chain = SequentialChain(
+        chains=chains,
+        input_variables=["question"],
+        output_variables=["revised_statement"],
+        verbose=True,
+    )
+    return question_to_checked_assertions_chain
+
+
 class LLMCheckerChain(Chain):
     """Chain for question-answering with self-verification.
@@ -26,16 +70,21 @@ class LLMCheckerChain(Chain):

             from langchain import OpenAI, LLMCheckerChain
             llm = OpenAI(temperature=0.7)
-            checker_chain = LLMCheckerChain(llm=llm)
+            checker_chain = LLMCheckerChain.from_llm(llm)
     """

-    llm: BaseLLM
-    """LLM wrapper to use."""
+    question_to_checked_assertions_chain: SequentialChain
+
+    llm: Optional[BaseLLM] = None
+    """[Deprecated] LLM wrapper to use."""
     create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
+    """[Deprecated]"""
     list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
+    """[Deprecated]"""
     check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
+    """[Deprecated]"""
     revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
-    """Prompt to use when questioning the documents."""
+    """[Deprecated] Prompt to use when questioning the documents."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -45,6 +94,34 @@ class LLMCheckerChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMCheckerChain with an llm is deprecated. "
+                "Please instantiate with question_to_checked_assertions_chain "
+                "or using the from_llm class method."
+            )
+            if (
+                "question_to_checked_assertions_chain" not in values
+                and values["llm"] is not None
+            ):
+                question_to_checked_assertions_chain = (
+                    _load_question_to_checked_assertions_chain(
+                        values["llm"],
+                        values.get(
+                            "create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT
+                        ),
+                        values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT),
+                        values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT),
+                        values.get("revised_answer_prompt", REVISED_ANSWER_PROMPT),
+                    )
+                )
+                values[
+                    "question_to_checked_assertions_chain"
+                ] = question_to_checked_assertions_chain
+        return values
+
     @property
     def input_keys(self) -> List[str]:
         """Return the singular input key.
@@ -61,43 +138,43 @@ class LLMCheckerChain(Chain):
     """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]

-        create_draft_answer_chain = LLMChain(
-            llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement"
-        )
-        list_assertions_chain = LLMChain(
-            llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions"
-        )
-        check_assertions_chain = LLMChain(
-            llm=self.llm,
-            prompt=self.check_assertions_prompt,
-            output_key="checked_assertions",
-        )
-
-        revised_answer_chain = LLMChain(
-            llm=self.llm,
-            prompt=self.revised_answer_prompt,
-            output_key="revised_statement",
-        )
-
-        chains = [
-            create_draft_answer_chain,
-            list_assertions_chain,
-            check_assertions_chain,
-            revised_answer_chain,
-        ]
-
-        question_to_checked_assertions_chain = SequentialChain(
-            chains=chains,
-            input_variables=["question"],
-            output_variables=["revised_statement"],
-            verbose=True,
-        )
-        output = question_to_checked_assertions_chain({"question": question})
+        output = self.question_to_checked_assertions_chain(
+            {"question": question}, callbacks=_run_manager.get_child()
+        )
         return {self.output_key: output["revised_statement"]}

     @property
     def _chain_type(self) -> str:
         return "llm_checker_chain"
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLLM,
+        create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT,
+        list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT,
+        check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
+        revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT,
+        **kwargs: Any,
+    ) -> LLMCheckerChain:
+        question_to_checked_assertions_chain = (
+            _load_question_to_checked_assertions_chain(
+                llm,
+                create_draft_answer_prompt,
+                list_assertions_prompt,
+                check_assertions_prompt,
+                revised_answer_prompt,
+            )
+        )
+        return cls(
+            question_to_checked_assertions_chain=question_to_checked_assertions_chain,
+            **kwargs,
+        )
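Migration sketch: the four sub-chains are now assembled once by the helper, so `from_llm` replaces direct instantiation. With a deterministic stand-in LLM (FakeListLLM, assumed available; four responses, one per sub-chain step):

from langchain.chains import LLMCheckerChain
from langchain.llms.fake import FakeListLLM

# One canned response each for: draft, assertions, checks, revision.
llm = FakeListLLM(
    responses=["draft answer", "assertion list", "checked assertions", "revised answer"]
)
checker = LLMCheckerChain.from_llm(llm)  # replaces LLMCheckerChain(llm=llm)
print(checker.run("What type of mammal lays the biggest eggs?"))  # -> "revised answer"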
langchain/chains/llm_math/base.py
@@ -1,16 +1,23 @@
 """Chain that interprets a prompt and executes python code to do math."""
 from __future__ import annotations

 import math
 import re
-from typing import Dict, List
+import warnings
+from typing import Any, Dict, List, Optional

 import numexpr
-from pydantic import Extra
+from pydantic import Extra, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class LLMMathChain(Chain):
@@ -20,13 +27,14 @@ class LLMMathChain(Chain):
     .. code-block:: python

             from langchain import LLMMathChain, OpenAI
-            llm_math = LLMMathChain(llm=OpenAI())
+            llm_math = LLMMathChain.from_llm(OpenAI())
     """

-    llm: BaseLanguageModel
-    """LLM wrapper to use."""
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated] LLM wrapper to use."""
     prompt: BasePromptTemplate = PROMPT
-    """Prompt to use to translate to python if neccessary."""
+    """[Deprecated] Prompt to use to translate to python if necessary."""
     input_key: str = "question"  #: :meta private:
     output_key: str = "answer"  #: :meta private:
@@ -36,6 +44,19 @@ class LLMMathChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMMathChain with an llm is deprecated. "
+                "Please instantiate with llm_chain argument or using the from_llm "
+                "class method."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                prompt = values.get("prompt", PROMPT)
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+        return values
+
     @property
     def input_keys(self) -> List[str]:
         """Expect input key.
@@ -68,15 +89,17 @@ class LLMMathChain(Chain):
         # Remove any leading and trailing brackets from the output
         return re.sub(r"^\[|\]$", "", output)

-    def _process_llm_result(self, llm_output: str) -> Dict[str, str]:
-        self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose)
+    def _process_llm_result(
+        self, llm_output: str, run_manager: CallbackManagerForChainRun
+    ) -> Dict[str, str]:
+        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
         llm_output = llm_output.strip()
         text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
         if text_match:
             expression = text_match.group(1)
             output = self._evaluate_expression(expression)
-            self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
-            self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
+            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            run_manager.on_text(output, color="yellow", verbose=self.verbose)
             answer = "Answer: " + output
         elif llm_output.startswith("Answer:"):
             answer = llm_output
@@ -86,30 +109,19 @@ class LLMMathChain(Chain):
             raise ValueError(f"unknown format from LLM: {llm_output}")
         return {self.output_key: answer}

-    async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_text(
-                llm_output, color="green", verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_text(
-                llm_output, color="green", verbose=self.verbose
-            )
+    async def _aprocess_llm_result(
+        self,
+        llm_output: str,
+        run_manager: AsyncCallbackManagerForChainRun,
+    ) -> Dict[str, str]:
+        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
         llm_output = llm_output.strip()
         text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
         if text_match:
             expression = text_match.group(1)
             output = self._evaluate_expression(expression)
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
-                await self.callback_manager.on_text(
-                    output, color="yellow", verbose=self.verbose
-                )
-            else:
-                self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
-                self.callback_manager.on_text(
-                    output, color="yellow", verbose=self.verbose
-                )
+            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
             answer = "Answer: " + output
         elif llm_output.startswith("Answer:"):
             answer = llm_output
@@ -119,31 +131,44 @@ class LLMMathChain(Chain):
             raise ValueError(f"unknown format from LLM: {llm_output}")
         return {self.output_key: answer}

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
-        llm_executor = LLMChain(
-            prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
-        )
-        self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
-        llm_output = llm_executor.predict(
-            question=inputs[self.input_key], stop=["```output"]
-        )
-        return self._process_llm_result(llm_output)
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        _run_manager.on_text(inputs[self.input_key])
+        llm_output = self.llm_chain.predict(
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return self._process_llm_result(llm_output, _run_manager)

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
-        llm_executor = LLMChain(
-            prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
-        )
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_text(
-                inputs[self.input_key], verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
-        llm_output = await llm_executor.apredict(
-            question=inputs[self.input_key], stop=["```output"]
-        )
-        return await self._aprocess_llm_result(llm_output)
+    async def _acall(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        await _run_manager.on_text(inputs[self.input_key])
+        llm_output = await self.llm_chain.apredict(
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return await self._aprocess_llm_result(llm_output, _run_manager)

     @property
     def _chain_type(self) -> str:
         return "llm_math_chain"
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        prompt: BasePromptTemplate = PROMPT,
+        **kwargs: Any,
+    ) -> LLMMathChain:
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
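A migration sketch for the new construction path and the deprecation shim (FakeListLLM is a deterministic stand-in assumed available; the canned response mimics the "```text ... ```" format the chain expects):

import warnings

from langchain.chains import LLMMathChain
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["```text\n2 + 2\n```"])

# New style: build the internal LLMChain via the class method.
chain = LLMMathChain.from_llm(llm)

# Old style still works, but the pre-validator now emits a warning
# and builds llm_chain on the caller's behalf.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LLMMathChain(llm=FakeListLLM(responses=["unused"]))
    print(caught[0].message)

print(chain.run("What is 2 + 2?"))  # -> "Answer: 4"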
langchain/chains/llm_requests.py
@@ -1,10 +1,11 @@
 """Chain that hits a URL and then uses an LLM to parse results."""
 from __future__ import annotations

-from typing import Dict, List
+from typing import Any, Dict, List, Optional

 from pydantic import Extra, Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains import LLMChain
 from langchain.chains.base import Chain
 from langchain.requests import TextRequestsWrapper
@@ -61,9 +62,14 @@ class LLMRequestsChain(Chain):
         )
         return values

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         from bs4 import BeautifulSoup

+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         # Other keys are assumed to be needed for LLM prediction
         other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
         url = inputs[self.input_key]
@@ -71,7 +77,9 @@ class LLMRequestsChain(Chain):
         # extract the text from the html
         soup = BeautifulSoup(res, "html.parser")
         other_keys[self.requests_key] = soup.get_text()[: self.text_length]
-        result = self.llm_chain.predict(**other_keys)
+        result = self.llm_chain.predict(
+            callbacks=_run_manager.get_child(), **other_keys
+        )
         return {self.output_key: result}

     @property
langchain/chains/llm_summarization_checker/base.py
@@ -1,10 +1,14 @@
 """Chain for summarization with self-verification."""

 from __future__ import annotations

+import warnings
 from pathlib import Path
-from typing import Dict, List
+from typing import Any, Dict, List, Optional

-from pydantic import Extra
+from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.sequential import SequentialChain
@@ -27,6 +31,48 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
 )


+def _load_sequential_chain(
+    llm: BaseLLM,
+    create_assertions_prompt: PromptTemplate,
+    check_assertions_prompt: PromptTemplate,
+    revised_summary_prompt: PromptTemplate,
+    are_all_true_prompt: PromptTemplate,
+    verbose: bool = False,
+) -> SequentialChain:
+    chain = SequentialChain(
+        chains=[
+            LLMChain(
+                llm=llm,
+                prompt=create_assertions_prompt,
+                output_key="assertions",
+                verbose=verbose,
+            ),
+            LLMChain(
+                llm=llm,
+                prompt=check_assertions_prompt,
+                output_key="checked_assertions",
+                verbose=verbose,
+            ),
+            LLMChain(
+                llm=llm,
+                prompt=revised_summary_prompt,
+                output_key="revised_summary",
+                verbose=verbose,
+            ),
+            LLMChain(
+                llm=llm,
+                output_key="all_true",
+                prompt=are_all_true_prompt,
+                verbose=verbose,
+            ),
+        ],
+        input_variables=["summary"],
+        output_variables=["all_true", "revised_summary"],
+        verbose=verbose,
+    )
+    return chain
+
+
 class LLMSummarizationCheckerChain(Chain):
     """Chain for question-answering with self-verification.
@@ -35,16 +81,21 @@ class LLMSummarizationCheckerChain(Chain):

             from langchain import OpenAI, LLMSummarizationCheckerChain
             llm = OpenAI(temperature=0.0)
-            checker_chain = LLMSummarizationCheckerChain(llm=llm)
+            checker_chain = LLMSummarizationCheckerChain.from_llm(llm)
     """

-    llm: BaseLLM
-    """LLM wrapper to use."""
+    sequential_chain: SequentialChain
+    llm: Optional[BaseLLM] = None
+    """[Deprecated] LLM wrapper to use."""

     create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
+    """[Deprecated]"""
     check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
+    """[Deprecated]"""
     revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
+    """[Deprecated]"""
     are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT
+    """[Deprecated]"""

     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -57,6 +108,25 @@ class LLMSummarizationCheckerChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMSummarizationCheckerChain with an llm is "
+                "deprecated. Please instantiate with"
+                " sequential_chain argument or using the from_llm class method."
+            )
+            if "sequential_chain" not in values and values["llm"] is not None:
+                values["sequential_chain"] = _load_sequential_chain(
+                    values["llm"],
+                    values.get("create_assertions_prompt", CREATE_ASSERTIONS_PROMPT),
+                    values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT),
+                    values.get("revised_summary_prompt", REVISED_SUMMARY_PROMPT),
+                    values.get("are_all_true_prompt", ARE_ALL_TRUE_PROMPT),
+                    verbose=values.get("verbose", False),
+                )
+        return values
+
     @property
     def input_keys(self) -> List[str]:
         """Return the singular input key.
@@ -73,46 +143,21 @@ class LLMSummarizationCheckerChain(Chain):
     """
        return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         all_true = False
         count = 0
         output = None
         original_input = inputs[self.input_key]
         chain_input = original_input

         while not all_true and count < self.max_checks:
-            chain = SequentialChain(
-                chains=[
-                    LLMChain(
-                        llm=self.llm,
-                        prompt=self.create_assertions_prompt,
-                        output_key="assertions",
-                        verbose=self.verbose,
-                    ),
-                    LLMChain(
-                        llm=self.llm,
-                        prompt=self.check_assertions_prompt,
-                        output_key="checked_assertions",
-                        verbose=self.verbose,
-                    ),
-                    LLMChain(
-                        llm=self.llm,
-                        prompt=self.revised_summary_prompt,
-                        output_key="revised_summary",
-                        verbose=self.verbose,
-                    ),
-                    LLMChain(
-                        llm=self.llm,
-                        output_key="all_true",
-                        prompt=self.are_all_true_prompt,
-                        verbose=self.verbose,
-                    ),
-                ],
-                input_variables=["summary"],
-                output_variables=["all_true", "revised_summary"],
-                verbose=self.verbose,
+            output = self.sequential_chain(
+                {"summary": chain_input}, callbacks=_run_manager.get_child()
             )
-            output = chain({"summary": chain_input})
             count += 1

             if output["all_true"].strip() == "True":
@@ -131,3 +176,24 @@ class LLMSummarizationCheckerChain(Chain):
     @property
     def _chain_type(self) -> str:
         return "llm_summarization_checker_chain"
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLLM,
+        create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT,
+        check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
+        revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT,
+        are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> LLMSummarizationCheckerChain:
+        chain = _load_sequential_chain(
+            llm,
+            create_assertions_prompt,
+            check_assertions_prompt,
+            revised_summary_prompt,
+            are_all_true_prompt,
+            verbose=verbose,
+        )
+        return cls(sequential_chain=chain, verbose=verbose, **kwargs)
langchain/chains/mapreduce.py
@@ -5,10 +5,11 @@ then combines the results with another one.
 """
 from __future__ import annotations

-from typing import Dict, List
+from typing import Any, Dict, List, Optional

 from pydantic import Extra

+from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -32,16 +33,26 @@ class MapReduceChain(Chain):

     @classmethod
     def from_params(
-        cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
+        cls,
+        llm: BaseLLM,
+        prompt: BasePromptTemplate,
+        text_splitter: TextSplitter,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
-        reduce_chain = StuffDocumentsChain(llm_chain=llm_chain)
+        llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
+        reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks)
         combine_documents_chain = MapReduceDocumentsChain(
-            llm_chain=llm_chain, combine_document_chain=reduce_chain
+            llm_chain=llm_chain,
+            combine_document_chain=reduce_chain,
+            callbacks=callbacks,
         )
         return cls(
-            combine_documents_chain=combine_documents_chain, text_splitter=text_splitter
+            combine_documents_chain=combine_documents_chain,
+            text_splitter=text_splitter,
+            callbacks=callbacks,
+            **kwargs,
         )

     class Config:
@@ -66,9 +77,16 @@ class MapReduceChain(Chain):
     """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         # Split the larger text into smaller chunks.
         texts = self.text_splitter.split_text(inputs[self.input_key])
         docs = [Document(page_content=text) for text in texts]
-        outputs = self.combine_documents_chain.run(input_documents=docs)
+        outputs = self.combine_documents_chain.run(
+            input_documents=docs, callbacks=_run_manager.get_child()
+        )
         return {self.output_key: outputs}
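A rough sketch of the extended from_params, under the assumption that FakeListLLM and the default text splitter behave as in recent langchain versions (the handler here is attached to every constructed sub-chain):

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.mapreduce import MapReduceChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter

chain = MapReduceChain.from_params(
    llm=FakeListLLM(responses=["summary A", "summary B"]),
    prompt=PromptTemplate.from_template("Summarize:\n\n{text}"),
    text_splitter=CharacterTextSplitter(),
    callbacks=[StdOutCallbackHandler()],
)
print(chain.run("Some long text to summarize."))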
langchain/chains/moderation.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional

 from pydantic import root_validator

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.utils import get_from_dict_or_env

@@ -84,7 +85,11 @@ class OpenAIModerationChain(Chain):
             return error_str
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
         text = inputs[self.input_key]
         results = self.client.create(text)
         output = self._moderate(text, results["results"][0])
@@ -1,10 +1,12 @@
|
||||
"""Implement an LLM driven browser."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.natbot.prompt import PROMPT
|
||||
@@ -18,14 +20,15 @@ class NatBotChain(Chain):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import NatBotChain, OpenAI
|
||||
natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
|
||||
from langchain import NatBotChain
|
||||
natbot = NatBotChain.from_default("Buy me a new hat.")
|
||||
"""
|
||||
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
llm_chain: LLMChain
|
||||
objective: str
|
||||
"""Objective that NatBot is tasked with completing."""
|
||||
llm: Optional[BaseLLM] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
input_url_key: str = "url" #: :meta private:
|
||||
input_browser_content_key: str = "browser_content" #: :meta private:
|
||||
previous_command: str = "" #: :meta private:
|
||||
@@ -37,11 +40,24 @@ class NatBotChain(Chain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an NatBotChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_default "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_default(cls, objective: str) -> NatBotChain:
|
||||
"""Load with default LLM."""
|
||||
"""Load with default LLMChain."""
|
||||
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
|
||||
return cls(llm=llm, objective=objective)
|
||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||
return cls(llm_chain=llm_chain, objective=objective)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -59,15 +75,20 @@ class NatBotChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
url = inputs[self.input_url_key]
|
||||
browser_content = inputs[self.input_browser_content_key]
|
||||
llm_cmd = llm_executor.predict(
|
||||
llm_cmd = self.llm_chain.predict(
|
||||
objective=self.objective,
|
||||
url=url[:100],
|
||||
previous_command=self.previous_command,
|
||||
browser_content=browser_content[:4500],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
llm_cmd = llm_cmd.strip()
|
||||
self.previous_command = llm_cmd
|
||||
|
||||
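NatBotChain introduces the deprecation shim that PALChain and SQLDatabaseChain repeat below: a pre=True root validator that warns about the legacy llm argument and converts it into the new llm_chain field before pydantic checks required fields. A generic sketch of that shim under the pydantic v1 API used here (MyChain and MY_PROMPT are placeholders, and the abstract Chain members are omitted for brevity):

    import warnings
    from typing import Dict, Optional

    from pydantic import root_validator

    from langchain.base_language import BaseLanguageModel
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain
    from langchain.prompts import PromptTemplate

    MY_PROMPT = PromptTemplate(input_variables=["query"], template="{query}")


    class MyChain(Chain):
        # _call, input_keys, output_keys omitted; this shows only the shim.
        llm_chain: LLMChain
        llm: Optional[BaseLanguageModel] = None
        """[Deprecated]"""

        @root_validator(pre=True)
        def raise_deprecation(cls, values: Dict) -> Dict:
            # pre=True runs before field validation, so the shim can fill in
            # llm_chain before pydantic complains that it is missing.
            if "llm" in values:
                warnings.warn(
                    "Instantiating with an llm is deprecated; pass llm_chain."
                )
                if "llm_chain" not in values and values["llm"] is not None:
                    values["llm_chain"] = LLMChain(llm=values["llm"], prompt=MY_PROMPT)
            return values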
@@ -4,24 +4,29 @@ As in https://arxiv.org/pdf/2211.10435.pdf.
 """
 from __future__ import annotations

+import warnings
 from typing import Any, Dict, List, Optional

-from pydantic import Extra
+from pydantic import Extra, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
 from langchain.chains.pal.math_prompt import MATH_PROMPT
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.utilities import PythonREPL


 class PALChain(Chain):
     """Implements Program-Aided Language Models."""

-    llm: BaseLanguageModel
-    prompt: BasePromptTemplate
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated]"""
+    prompt: BasePromptTemplate = MATH_PROMPT
+    """[Deprecated]"""
     stop: str = "\n\n"
     get_answer_expr: str = "print(solution())"
     python_globals: Optional[Dict[str, Any]] = None

@@ -35,6 +40,19 @@ class PALChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating a PALChain with an llm is deprecated. "
+                "Please instantiate with the llm_chain argument or use one of "
+                "the class method constructors from_math_prompt, "
+                "from_colored_object_prompt."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=MATH_PROMPT)
+        return values
+
     @property
     def input_keys(self) -> List[str]:
         """Return the singular input key.

@@ -54,12 +72,16 @@ class PALChain(Chain):
         else:
             return [self.output_key, "intermediate_steps"]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
-        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
-        code = llm_chain.predict(stop=[self.stop], **inputs)
-        self.callback_manager.on_text(
-            code, color="green", end="\n", verbose=self.verbose
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        code = self.llm_chain.predict(
+            stop=[self.stop], callbacks=_run_manager.get_child(), **inputs
         )
+        _run_manager.on_text(code, color="green", end="\n", verbose=self.verbose)
         repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals)
         res = repl.run(code + f"\n{self.get_answer_expr}")
         output = {self.output_key: res.strip()}

@@ -70,9 +92,9 @@ class PALChain(Chain):
     @classmethod
     def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
         """Load PAL from math prompt."""
+        llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT)
         return cls(
-            llm=llm,
-            prompt=MATH_PROMPT,
+            llm_chain=llm_chain,
             stop="\n\n",
             get_answer_expr="print(solution())",
             **kwargs,

@@ -83,9 +105,9 @@ class PALChain(Chain):
         cls, llm: BaseLanguageModel, **kwargs: Any
     ) -> PALChain:
         """Load PAL from colored object prompt."""
+        llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT)
         return cls(
-            llm=llm,
-            prompt=COLORED_OBJECT_PROMPT,
+            llm_chain=llm_chain,
             stop="\n\n\n",
             get_answer_expr="print(answer)",
             **kwargs,
@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple

 from pydantic import BaseModel, Field

+from langchain.base_language import BaseLanguageModel
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class BasePromptSelector(BaseModel, ABC):

@@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional

 from pydantic import Field

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


@@ -45,11 +46,14 @@ class QAGenerationChain(Chain):
     def output_keys(self) -> List[str]:
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, List]:
         docs = self.text_splitter.create_documents([inputs[self.input_key]])
-        results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
+        results = self.llm_chain.generate(
+            [{"text": d.page_content} for d in docs], run_manager=run_manager
+        )
         qa = [json.loads(res[0].text) for res in results.generations]
         return {self.output_key: qa}

     async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
         raise NotImplementedError

@@ -8,6 +8,11 @@ from typing import Any, Dict, List, Optional

 from pydantic import Extra, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain

@@ -21,7 +26,6 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
 )
 from langchain.docstore.document import Document
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class BaseQAWithSourcesChain(Chain, ABC):

@@ -114,9 +118,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
     def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
         """Get docs to run questioning over."""

-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         docs = self._get_docs(inputs)
-        answer = self.combine_documents_chain.run(input_documents=docs, **inputs)
+        answer = self.combine_documents_chain.run(
+            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
+        )
         if re.search(r"SOURCES:\s", answer):
             answer, sources = re.split(r"SOURCES:\s", answer)
         else:

@@ -133,9 +144,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
     async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
         """Get docs to run questioning over."""

-    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         docs = await self._aget_docs(inputs)
-        answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs)
+        answer = await self.combine_documents_chain.arun(
+            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
+        )
         if re.search(r"SOURCES:\s", answer):
             answer, sources = re.split(r"SOURCES:\s", answer)
         else:

@@ -1,6 +1,7 @@
 """Load question answering with sources chains."""
 from typing import Any, Mapping, Optional, Protocol

+from langchain.base_language import BaseLanguageModel
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain

@@ -14,7 +15,6 @@ from langchain.chains.qa_with_sources import (
 )
 from langchain.chains.question_answering import map_rerank_prompt
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class LoadingCallable(Protocol):

@@ -5,6 +5,7 @@ import json
 from typing import Any, Callable, List, Optional, Sequence

 from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
+from langchain.base_language import BaseLanguageModel
 from langchain.chains.query_constructor.ir import (
     Comparator,
     Operator,

@@ -20,7 +21,7 @@ from langchain.chains.query_constructor.prompt import (
 )
 from langchain.chains.query_constructor.schema import AttributeInfo
 from langchain.output_parsers.structured import parse_json_markdown
-from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
+from langchain.schema import BaseOutputParser, OutputParserException


 class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):

@@ -1,6 +1,7 @@
 """Load question answering chains."""
 from typing import Any, Mapping, Optional, Protocol

+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain

@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
     stuff_prompt,
 )
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class LoadingCallable(Protocol):

@@ -7,6 +7,11 @@ from typing import Any, Dict, List, Optional

 from pydantic import Extra, Field, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain

@@ -14,7 +19,7 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel, BaseRetriever, Document
+from langchain.schema import BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore


@@ -92,7 +97,11 @@ class BaseRetrievalQA(Chain):
     def _get_docs(self, question: str) -> List[Document]:
         """Get documents to do question answering over."""

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run get_relevant_text and llm on input query.

        If chain has 'return_source_documents' as 'True', returns

@@ -104,11 +113,12 @@ class BaseRetrievalQA(Chain):
         res = indexqa({'query': 'This is my query'})
         answer, docs = res['result'], res['source_documents']
         """
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]

         docs = self._get_docs(question)
         answer = self.combine_documents_chain.run(
-            input_documents=docs, question=question
+            input_documents=docs, question=question, callbacks=_run_manager.get_child()
         )

         if self.return_source_documents:

@@ -120,7 +130,11 @@ class BaseRetrievalQA(Chain):
     async def _aget_docs(self, question: str) -> List[Document]:
         """Get documents to do question answering over."""

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run get_relevant_text and llm on input query.

        If chain has 'return_source_documents' as 'True', returns

@@ -132,11 +146,12 @@ class BaseRetrievalQA(Chain):
         res = indexqa({'query': 'This is my query'})
         answer, docs = res['result'], res['source_documents']
         """
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]

         docs = await self._aget_docs(question)
         answer = await self.combine_documents_chain.arun(
-            input_documents=docs, question=question
+            input_documents=docs, question=question, callbacks=_run_manager.get_child()
         )

         if self.return_source_documents:
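The qa_with_sources and retrieval_qa hunks show the async half of the same protocol: _acall mirrors _call but takes an AsyncCallbackManagerForChainRun, whose hooks are coroutines and must be awaited. A hedged sketch of the pairing (AsyncEchoChain is illustrative only, not part of the diff):

    import asyncio
    from typing import Any, Dict, List, Optional

    from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
    from langchain.chains.base import Chain


    class AsyncEchoChain(Chain):
        """Hypothetical chain showing the async side of the pattern."""

        @property
        def input_keys(self) -> List[str]:
            return ["query"]

        @property
        def output_keys(self) -> List[str]:
            return ["result"]

        def _call(self, inputs, run_manager=None):
            # Sync twin kept trivial; see the earlier sketch for the full shape.
            return {"result": inputs["query"]}

        async def _acall(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
        ) -> Dict[str, Any]:
            _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
            # The async manager's hooks are coroutines, so they must be awaited.
            await _run_manager.on_text(inputs["query"], verbose=self.verbose)
            return {"result": inputs["query"]}


    print(asyncio.run(AsyncEchoChain().acall({"query": "hi"})))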
@@ -1,8 +1,12 @@
 """Chain pipeline where the outputs of one step feed directly into next."""
-from typing import Dict, List
+from typing import Any, Dict, List, Optional

 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
 from langchain.chains.base import Chain
 from langchain.input import get_color_mapping

@@ -86,17 +90,31 @@ class SequentialChain(Chain):

         return values

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
         known_values = inputs.copy()
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         for i, chain in enumerate(self.chains):
-            outputs = chain(known_values, return_only_outputs=True)
+            callbacks = _run_manager.get_child()
+            outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks)
             known_values.update(outputs)
         return {k: known_values[k] for k in self.output_variables}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         known_values = inputs.copy()
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        callbacks = _run_manager.get_child()
         for i, chain in enumerate(self.chains):
-            outputs = await chain.acall(known_values, return_only_outputs=True)
+            outputs = await chain.acall(
+                known_values, return_only_outputs=True, callbacks=callbacks
+            )
             known_values.update(outputs)
         return {k: known_values[k] for k in self.output_variables}

@@ -147,31 +165,37 @@ class SimpleSequentialChain(Chain):
         )
         return values

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         _input = inputs[self.input_key]
         color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
         for i, chain in enumerate(self.chains):
-            _input = chain.run(_input)
+            _input = chain.run(_input, callbacks=_run_manager.get_child())
             if self.strip_outputs:
                 _input = _input.strip()
-            self.callback_manager.on_text(
+            _run_manager.on_text(
                 _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
             )
         return {self.output_key: _input}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        callbacks = _run_manager.get_child()
         _input = inputs[self.input_key]
         color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
         for i, chain in enumerate(self.chains):
-            _input = await chain.arun(_input)
+            _input = await chain.arun(_input, callbacks=callbacks)
             if self.strip_outputs:
                 _input = _input.strip()
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_text(
-                    _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
-                )
-            else:
-                self.callback_manager.on_text(
-                    _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
-                )
+            await _run_manager.on_text(
+                _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
+            )
         return {self.output_key: _input}
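Both sequential variants now hand callbacks=_run_manager.get_child() to every sub-chain, so a handler passed once at the top level observes the nested runs as children of a single parent run. A usage sketch built from TransformChain, which this same changeset migrates; the StdOutCallbackHandler import path and the post-refactor run(..., callbacks=...) signature are assumptions based on the hunks here:

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.chains import SimpleSequentialChain, TransformChain

    upper = TransformChain(
        input_variables=["text"],
        output_variables=["shout"],
        transform=lambda inputs: {"shout": inputs["text"].upper()},
    )
    exclaim = TransformChain(
        input_variables=["shout"],
        output_variables=["final"],
        transform=lambda inputs: {"final": inputs["shout"] + "!"},
    )
    chain = SimpleSequentialChain(chains=[upper, exclaim])

    # The handler is attached once at the top; get_child() forwards it to
    # each sub-chain run, so all three runs share one trace.
    print(chain.run("hello", callbacks=[StdOutCallbackHandler()]))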
@@ -1,15 +1,17 @@
 """Chain for interacting with SQL Database."""
 from __future__ import annotations

+import warnings
 from typing import Any, Dict, List, Optional

-from pydantic import Extra, Field
+from pydantic import Extra, Field, root_validator

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.sql_database import SQLDatabase


@@ -21,15 +23,16 @@ class SQLDatabaseChain(Chain):

         from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
         db = SQLDatabase(...)
-        db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
+        db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)
     """

-    llm: BaseLanguageModel
-    """LLM wrapper to use."""
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated] LLM wrapper to use."""
     database: SQLDatabase = Field(exclude=True)
     """SQL Database to connect to."""
     prompt: Optional[BasePromptTemplate] = None
-    """Prompt to use to translate natural language to SQL."""
+    """[Deprecated] Prompt to use to translate natural language to SQL."""
     top_k: int = 5
     """Number of results to return from the query"""
     input_key: str = "query"  #: :meta private:

@@ -45,6 +48,22 @@ class SQLDatabaseChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an SQLDatabaseChain with an llm is "
+                "deprecated. Please instantiate with the llm_chain argument "
+                "or use the from_llm class method."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                database = values["database"]
+                prompt = values.get("prompt") or SQL_PROMPTS.get(
+                    database.dialect, PROMPT
+                )
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+        return values
+
     @property
     def input_keys(self) -> List[str]:
         """Return the singular input key.

@@ -64,11 +83,14 @@ class SQLDatabaseChain(Chain):
         else:
             return [self.output_key, "intermediate_steps"]

-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        prompt = self.prompt or SQL_PROMPTS.get(self.database.dialect, PROMPT)
-        llm_chain = LLMChain(llm=self.llm, prompt=prompt)
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         input_text = f"{inputs[self.input_key]}\nSQLQuery:"
-        self.callback_manager.on_text(input_text, verbose=self.verbose)
+        _run_manager.on_text(input_text, verbose=self.verbose)
         # If not present, then defaults to None which is all tables.
         table_names_to_use = inputs.get("table_names_to_use")
         table_info = self.database.get_table_info(table_names=table_names_to_use)

@@ -80,24 +102,26 @@ class SQLDatabaseChain(Chain):
             "stop": ["\nSQLResult:"],
         }
         intermediate_steps = []
-        sql_cmd = llm_chain.predict(**llm_inputs)
+        sql_cmd = self.llm_chain.predict(
+            callbacks=_run_manager.get_child(), **llm_inputs
+        )
         intermediate_steps.append(sql_cmd)
-        self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
+        _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
         result = self.database.run(sql_cmd)
         intermediate_steps.append(result)
-        self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose)
-        self.callback_manager.on_text(result, color="yellow", verbose=self.verbose)
+        _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
+        _run_manager.on_text(result, color="yellow", verbose=self.verbose)
         # If return direct, we just set the final result equal to the sql query
         if self.return_direct:
             final_result = result
         else:
-            self.callback_manager.on_text("\nAnswer:", verbose=self.verbose)
+            _run_manager.on_text("\nAnswer:", verbose=self.verbose)
             input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
             llm_inputs["input"] = input_text
-            final_result = llm_chain.predict(**llm_inputs)
-            self.callback_manager.on_text(
-                final_result, color="green", verbose=self.verbose
-            )
+            final_result = self.llm_chain.predict(
+                callbacks=_run_manager.get_child(), **llm_inputs
+            )
+            _run_manager.on_text(final_result, color="green", verbose=self.verbose)
         chain_result: Dict[str, Any] = {self.output_key: final_result}
         if self.return_intermediate_steps:
             chain_result["intermediate_steps"] = intermediate_steps

@@ -107,6 +131,18 @@ class SQLDatabaseChain(Chain):
     def _chain_type(self) -> str:
         return "sql_database_chain"

+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        db: SQLDatabase,
+        prompt: Optional[BasePromptTemplate] = None,
+        **kwargs: Any,
+    ) -> SQLDatabaseChain:
+        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, database=db, **kwargs)
+

 class SQLDatabaseSequentialChain(Chain):
     """Chain for querying SQL database that is a sequential chain.

@@ -118,6 +154,10 @@ class SQLDatabaseSequentialChain(Chain):
     This is useful in cases where the number of tables in the database is large.
     """

+    decider_chain: LLMChain
+    sql_chain: SQLDatabaseChain
+    input_key: str = "query"  #: :meta private:
+    output_key: str = "result"  #: :meta private:
+    return_intermediate_steps: bool = False
+
     @classmethod

@@ -138,11 +178,6 @@ class SQLDatabaseSequentialChain(Chain):
         )
         return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

-    decider_chain: LLMChain
-    sql_chain: SQLDatabaseChain
-    input_key: str = "query"  #: :meta private:
-    output_key: str = "result"  #: :meta private:
-
     @property
     def input_keys(self) -> List[str]:
         """Return the singular input key.

@@ -162,25 +197,32 @@ class SQLDatabaseSequentialChain(Chain):
         else:
             return [self.output_key, "intermediate_steps"]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         _table_names = self.sql_chain.database.get_usable_table_names()
         table_names = ", ".join(_table_names)
         llm_inputs = {
             "query": inputs[self.input_key],
             "table_names": table_names,
         }
-        table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs)
-        self.callback_manager.on_text(
-            "Table names to use:", end="\n", verbose=self.verbose
-        )
-        self.callback_manager.on_text(
+        table_names_to_use = self.decider_chain.predict_and_parse(
+            callbacks=_run_manager.get_child(), **llm_inputs
+        )
+        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
+        _run_manager.on_text(
             str(table_names_to_use), color="yellow", verbose=self.verbose
         )
         new_inputs = {
             self.sql_chain.input_key: inputs[self.input_key],
             "table_names_to_use": table_names_to_use,
         }
-        return self.sql_chain(new_inputs, return_only_outputs=True)
+        return self.sql_chain(
+            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
+        )

     @property
     def _chain_type(self) -> str:
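from_llm becomes the supported constructor: it selects a dialect-specific prompt from SQL_PROMPTS and wraps the llm in an LLMChain, exactly as the deprecation validator does. A usage sketch, assuming a local SQLite database (the URI and question are illustrative, not from the diff):

    from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

    db = SQLDatabase.from_uri("sqlite:///chinook.db")  # hypothetical local file
    db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db, verbose=True)
    db_chain.run("How many employees are there?")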
@@ -1,6 +1,7 @@
 """Load summarizing chains."""
 from typing import Any, Mapping, Optional, Protocol

+from langchain.base_language import BaseLanguageModel
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain

@@ -8,7 +9,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel


 class LoadingCallable(Protocol):

@@ -1,6 +1,7 @@
 """Chain that runs an arbitrary python function."""
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Optional

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain


@@ -35,5 +36,9 @@ class TransformChain(Chain):
         """
         return self.output_variables

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
         return self.transform(inputs)

@@ -2,6 +2,10 @@ from typing import Any, Dict, List, Optional

 from pydantic import Extra

+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.anthropic import _AnthropicCommon
 from langchain.schema import (

@@ -85,7 +89,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         )  # trim off the trailing ' ' that might come from the "Assistant: "

     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
         params: Dict[str, Any] = {"prompt": prompt, **self._default_params}

@@ -98,10 +105,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             for data in stream_resp:
                 delta = data["completion"][len(completion) :]
                 completion = data["completion"]
-                self.callback_manager.on_llm_new_token(
-                    delta,
-                    verbose=self.verbose,
-                )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        delta,
+                    )
         else:
             response = self.client.completion(**params)
             completion = response["completion"]

@@ -109,7 +116,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         return ChatResult(generations=[ChatGeneration(message=message)])

     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
         params: Dict[str, Any] = {"prompt": prompt, **self._default_params}

@@ -122,15 +132,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             async for data in stream_resp:
                 delta = data["completion"][len(completion) :]
                 completion = data["completion"]
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_new_token(
-                        delta,
-                        verbose=self.verbose,
-                    )
-                else:
-                    self.callback_manager.on_llm_new_token(
-                        delta,
-                        verbose=self.verbose,
-                    )
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        delta,
+                    )
         else:
             response = await self.client.acompletion(**params)
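The slicing in both streaming branches implies that each event from this Anthropic client carries the cumulative completion so far, so the newly generated text is the suffix past what was already emitted. A standalone sketch of that delta computation (the event shape is inferred from the hunks, not from Anthropic's documentation):

    def stream_deltas(events):
        """Yield only the newly generated suffix of each streamed completion."""
        completion = ""
        for data in events:
            delta = data["completion"][len(completion):]
            completion = data["completion"]
            yield delta


    events = [{"completion": "Hel"}, {"completion": "Hello"}, {"completion": "Hello!"}]
    print(list(stream_deltas(events)))  # ['Hel', 'lo', '!']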
@@ -1,21 +1,30 @@
 import asyncio
+import inspect
+import warnings
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Dict, List, Optional

-from pydantic import Extra, Field, validator
+from pydantic import Extra, Field, root_validator

-import langchain
-from langchain.callbacks import get_callback_manager
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import (
+    AsyncCallbackManager,
+    AsyncCallbackManagerForLLMRun,
+    CallbackManager,
+    CallbackManagerForLLMRun,
+    Callbacks,
+)
 from langchain.schema import (
     AIMessage,
-    BaseLanguageModel,
     BaseMessage,
     ChatGeneration,
     ChatResult,
     HumanMessage,
     LLMResult,
     PromptValue,
+    get_buffer_string,
 )

@@ -26,7 +35,19 @@ def _get_verbosity() -> bool:
 class BaseChatModel(BaseLanguageModel, ABC):
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
+    callbacks: Callbacks = None
+    callback_manager: Optional[BaseCallbackManager] = None
+
+    @root_validator()
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        """Raise deprecation warning if callback_manager is used."""
+        if values.get("callback_manager") is not None:
+            warnings.warn(
+                "callback_manager is deprecated. Please use callbacks instead.",
+                DeprecationWarning,
+            )
+            values["callbacks"] = values.pop("callback_manager", None)
+        return values

     class Config:
         """Configuration for this pydantic object."""

@@ -34,98 +55,130 @@ class BaseChatModel(BaseLanguageModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True

-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
-
     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         return {}

     def generate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+        self,
+        messages: List[List[BaseMessage]],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         """Top Level call"""
-        results = [self._generate(m, stop=stop) for m in messages]
+        callback_manager = CallbackManager.configure(
+            callbacks, self.callbacks, self.verbose
+        )
+        message_strings = [get_buffer_string(m) for m in messages]
+        run_manager = callback_manager.on_llm_start(
+            {"name": self.__class__.__name__}, message_strings
+        )
+
+        new_arg_supported = inspect.signature(self._generate).parameters.get(
+            "run_manager"
+        )
+        try:
+            results = [
+                self._generate(m, stop=stop, run_manager=run_manager)
+                if new_arg_supported
+                else self._generate(m, stop=stop)
+                for m in messages
+            ]
+        except (KeyboardInterrupt, Exception) as e:
+            run_manager.on_llm_error(e)
+            raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        output = LLMResult(generations=generations, llm_output=llm_output)
+        run_manager.on_llm_end(output)
+        return output

     async def agenerate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+        self,
+        messages: List[List[BaseMessage]],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         """Top Level call"""
-        results = await asyncio.gather(
-            *[self._agenerate(m, stop=stop) for m in messages]
+        callback_manager = AsyncCallbackManager.configure(
+            callbacks, self.callbacks, self.verbose
         )
+        message_strings = [get_buffer_string(m) for m in messages]
+        run_manager = await callback_manager.on_llm_start(
+            {"name": self.__class__.__name__}, message_strings
+        )
+
+        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
+            "run_manager"
+        )
+        try:
+            results = await asyncio.gather(
+                *[
+                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    if new_arg_supported
+                    else self._agenerate(m, stop=stop)
+                    for m in messages
+                ]
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            await run_manager.on_llm_error(e)
+            raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        output = LLMResult(generations=generations, llm_output=llm_output)
+        await run_manager.on_llm_end(output)
+        return output

     def generate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+        self,
+        prompts: List[PromptValue],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        prompt_strings = [p.to_string() for p in prompts]
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
-        )
-        try:
-            output = self.generate(prompt_messages, stop=stop)
-        except (KeyboardInterrupt, Exception) as e:
-            self.callback_manager.on_llm_error(e, verbose=self.verbose)
-            raise e
-        self.callback_manager.on_llm_end(output, verbose=self.verbose)
-        return output
+        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)

     async def agenerate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+        self,
+        prompts: List[PromptValue],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        prompt_strings = [p.to_string() for p in prompts]
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
-            )
-        try:
-            output = await self.agenerate(prompt_messages, stop=stop)
-        except (KeyboardInterrupt, Exception) as e:
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_llm_error(e, verbose=self.verbose)
-            else:
-                self.callback_manager.on_llm_error(e, verbose=self.verbose)
-            raise e
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(output, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(output, verbose=self.verbose)
-        return output
+        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)

     @abstractmethod
     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> ChatResult:
         """Top Level call"""

     @abstractmethod
     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
         """Top Level call"""

     def __call__(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> BaseMessage:
-        return self._generate(messages, stop=stop).generations[0].message
+        generation = self.generate(
+            [messages], stop=stop, callbacks=callbacks
+        ).generations[0][0]
+        if isinstance(generation, ChatGeneration):
+            return generation.message
+        else:
+            raise ValueError("Unexpected generation type")

     def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
         result = self([HumanMessage(content=message)], stop=stop)

@@ -134,15 +187,21 @@ class BaseChatModel(BaseLanguageModel, ABC):

 class SimpleChatModel(BaseChatModel):
     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop)
+        output_str = self._call(messages, stop=stop, run_manager=run_manager)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])

     @abstractmethod
     def _call(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
         """Simpler interface."""
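BaseChatModel.generate inspects each subclass's _generate for a run_manager parameter before passing one, so third-party subclasses written against the old two-argument signature keep working. The dispatch in isolation, as a minimal self-contained sketch:

    import inspect


    def old_generate(m, stop=None):
        return f"old:{m}"


    def new_generate(m, stop=None, run_manager=None):
        return f"new:{m}:{run_manager}"


    def dispatch(fn, m, run_manager):
        # Pass run_manager only when the signature declares it, mirroring the
        # new_arg_supported check in BaseChatModel.generate above.
        if inspect.signature(fn).parameters.get("run_manager"):
            return fn(m, stop=None, run_manager=run_manager)
        return fn(m, stop=None)


    print(dispatch(old_generate, "hi", "RM"))  # old:hi
    print(dispatch(new_generate, "hi", "RM"))  # new:hi:RM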
@@ -14,6 +14,10 @@ from tenacity import (
     wait_exponential,
 )

+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
     AIMessage,

@@ -242,7 +246,10 @@ class ChatOpenAI(BaseChatModel):
         return {"token_usage": overall_token_usage, "model_name": self.model_name}

     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:

@@ -255,10 +262,8 @@ class ChatOpenAI(BaseChatModel):
                 role = stream_resp["choices"][0]["delta"].get("role", role)
                 token = stream_resp["choices"][0]["delta"].get("content", "")
                 inner_completion += token
-                self.callback_manager.on_llm_new_token(
-                    token,
-                    verbose=self.verbose,
-                )
+                if run_manager:
+                    run_manager.on_llm_new_token(token)
             message = _convert_dict_to_message(
                 {"content": inner_completion, "role": role}
             )

@@ -287,7 +292,10 @@ class ChatOpenAI(BaseChatModel):
         return ChatResult(generations=generations, llm_output=llm_output)

     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:

@@ -300,16 +308,8 @@ class ChatOpenAI(BaseChatModel):
                 role = stream_resp["choices"][0]["delta"].get("role", role)
                 token = stream_resp["choices"][0]["delta"].get("content", "")
                 inner_completion += token
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_new_token(
-                        token,
-                        verbose=self.verbose,
-                    )
-                else:
-                    self.callback_manager.on_llm_new_token(
-                        token,
-                        verbose=self.verbose,
-                    )
+                if run_manager:
+                    await run_manager.on_llm_new_token(token)
             message = _convert_dict_to_message(
                 {"content": inner_completion, "role": role}
             )

@@ -2,6 +2,10 @@
 import datetime
 from typing import List, Optional

+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models import ChatOpenAI
 from langchain.schema import BaseMessage, ChatResult

@@ -33,13 +37,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
     return_pl_id: Optional[bool] = False

     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request

         request_start_time = datetime.datetime.now().timestamp()
-        generated_responses = super()._generate(messages, stop)
+        generated_responses = super()._generate(messages, stop, run_manager)
         request_end_time = datetime.datetime.now().timestamp()
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):

@@ -67,13 +74,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         return generated_responses

     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async

         request_start_time = datetime.datetime.now().timestamp()
-        generated_responses = await super()._agenerate(messages, stop)
+        generated_responses = await super()._agenerate(messages, stop, run_manager)
         request_end_time = datetime.datetime.now().timestamp()
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
@@ -1,8 +1,11 @@
 """BabyAGI agent."""
 from collections import deque
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Field

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.experimental.autonomous_agents.baby_agi.task_creation import (
     TaskCreationChain,

@@ -13,7 +16,6 @@ from langchain.experimental.autonomous_agents.baby_agi.task_execution import (
 from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import (
     TaskPrioritizationChain,
 )
-from langchain.schema import BaseLanguageModel
 from langchain.vectorstores.base import VectorStore


@@ -112,7 +114,11 @@ class BabyAGI(Chain, BaseModel):
             objective=objective, context="\n".join(context), task=task
         )

-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the agent."""
         objective = inputs["objective"]
         first_task = inputs.get("first_task", "Make a todo list")

@@ -1,5 +1,5 @@
 from langchain import LLMChain, PromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language import BaseLanguageModel


 class TaskCreationChain(LLMChain):

@@ -1,5 +1,5 @@
 from langchain import LLMChain, PromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language import BaseLanguageModel


 class TaskExecutionChain(LLMChain):

@@ -1,5 +1,5 @@
 from langchain import LLMChain, PromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language import BaseLanguageModel


 class TaskPrioritizationChain(LLMChain):

@@ -5,9 +5,9 @@ from typing import Any, Dict, List, Optional, Tuple
 from pydantic import BaseModel, Field

 from langchain import LLMChain
+from langchain.base_language import BaseLanguageModel
 from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel


 class GenerativeAgent(BaseModel):

@@ -3,9 +3,10 @@ import re
 from typing import Any, Dict, List, Optional

 from langchain import LLMChain
+from langchain.base_language import BaseLanguageModel
 from langchain.prompts import PromptTemplate
 from langchain.retrievers import TimeWeightedVectorStoreRetriever
-from langchain.schema import BaseLanguageModel, BaseMemory, Document
+from langchain.schema import BaseMemory, Document

 logger = logging.getLogger(__name__)


@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
 import requests
 from pydantic import BaseModel, Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.utils import get_from_dict_or_env

@@ -106,7 +107,12 @@ class AI21(LLM):
         """Return type of llm."""
         return "ai21"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to AI21's complete endpoint.

         Args:

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence

 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env

@@ -200,7 +201,12 @@ class AlephAlpha(LLM):
         """Return type of llm."""
         return "alpeh_alpha"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to Aleph Alpha's completion endpoint.

         Args:

@@ -5,6 +5,10 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tupl

 from pydantic import BaseModel, Extra, root_validator

+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.utils import get_from_dict_or_env

@@ -158,7 +162,12 @@ class Anthropic(LLM, _AnthropicCommon):
         # As a last resort, wrap the prompt ourselves to emulate instruct-style.
         return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         r"""Call out to Anthropic's completion endpoint.

         Args:

@@ -187,9 +196,8 @@ class Anthropic(LLM, _AnthropicCommon):
             for data in stream_resp:
                 delta = data["completion"][len(current_completion) :]
                 current_completion = data["completion"]
-                self.callback_manager.on_llm_new_token(
-                    delta, verbose=self.verbose, **data
-                )
+                if run_manager:
+                    run_manager.on_llm_new_token(delta, **data)
             return current_completion
         response = self.client.completion(
             prompt=self._wrap_prompt(prompt),

@@ -198,7 +206,12 @@ class Anthropic(LLM, _AnthropicCommon):
         )
         return response["completion"]

-    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to Anthropic's completion endpoint asynchronously."""
         stop = self._get_anthropic_stop(stop)
         if self.streaming:

@@ -211,14 +224,8 @@ class Anthropic(LLM, _AnthropicCommon):
             async for data in stream_resp:
                 delta = data["completion"][len(current_completion) :]
                 current_completion = data["completion"]
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_new_token(
-                        delta, verbose=self.verbose, **data
-                    )
-                else:
-                    self.callback_manager.on_llm_new_token(
-                        delta, verbose=self.verbose, **data
-                    )
+                if run_manager:
+                    await run_manager.on_llm_new_token(delta, **data)
             return current_completion
         response = await self.client.acompletion(
             prompt=self._wrap_prompt(prompt),

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional

 from pydantic import Extra, Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env

@@ -80,7 +81,12 @@ class Banana(LLM):
         """Return type of llm."""
         return "banana"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call to Banana endpoint."""
         try:
             import banana_dev as banana
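LLM wrappers get the same treatment as chat models: _call accepts an optional CallbackManagerForLLMRun, and streaming implementations forward tokens through run_manager.on_llm_new_token behind a None check. A hypothetical minimal wrapper (EchoLLM is not a real integration):

    from typing import List, Optional

    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.llms.base import LLM


    class EchoLLM(LLM):
        """Hypothetical LLM that streams the prompt back word by word."""

        @property
        def _llm_type(self) -> str:
            return "echo"

        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
        ) -> str:
            for token in prompt.split():
                # Guarded, matching the hunks above: the manager is optional.
                if run_manager:
                    run_manager.on_llm_new_token(token)
            return prompt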
@@ -1,16 +1,25 @@
 """Base interface for large language models to expose."""
+import inspect
 import json
+import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

 import yaml
-from pydantic import Extra, Field, validator
+from pydantic import Extra, Field, root_validator, validator

 import langchain
-from langchain.callbacks import get_callback_manager
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
-from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
+from langchain.callbacks.manager import (
+    AsyncCallbackManager,
+    AsyncCallbackManagerForLLMRun,
+    CallbackManager,
+    CallbackManagerForLLMRun,
+    Callbacks,
+)
+from langchain.schema import Generation, LLMResult, PromptValue


 def _get_verbosity() -> bool:
@@ -59,7 +68,8 @@ class BaseLLM(BaseLanguageModel, ABC):
     cache: Optional[bool] = None
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
+    callbacks: Callbacks = None
+    callback_manager: Optional[BaseCallbackManager] = None

     class Config:
         """Configuration for this pydantic object."""
@@ -67,15 +77,16 @@ class BaseLLM(BaseLanguageModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True

-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
+    @root_validator()
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        """Raise deprecation warning if callback_manager is used."""
+        if values.get("callback_manager") is not None:
+            warnings.warn(
+                "callback_manager is deprecated. Please use callbacks instead.",
+                DeprecationWarning,
+            )
+            values["callbacks"] = values.pop("callback_manager", None)
+        return values

     @validator("verbose", pre=True, always=True)
     def set_verbose(cls, verbose: Optional[bool]) -> bool:
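The raise_deprecation validator above keeps the old callback_manager argument working while steering callers toward callbacks. A minimal self-contained sketch of the same pattern on a toy pydantic v1 model (Widget is illustrative, not part of the diff):

import warnings
from typing import Any, Dict, Optional

from pydantic import BaseModel, root_validator


class Widget(BaseModel):
    # New-style field plus the deprecated one, mirroring BaseLLM above.
    callbacks: Optional[Any] = None
    callback_manager: Optional[Any] = None

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        # Reroute the deprecated argument into the new field, with a warning.
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values


w = Widget(callback_manager="my-handler")  # emits DeprecationWarning
assert w.callbacks == "my-handler"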
@@ -90,30 +101,45 @@ class BaseLLM(BaseLanguageModel, ABC):

     @abstractmethod
     def _generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""

     @abstractmethod
     async def _agenerate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""

     def generate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+        self,
+        prompts: List[PromptValue],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return self.generate(prompt_strings, stop=stop)
+        return self.generate(prompt_strings, stop=stop, callbacks=callbacks)

     async def agenerate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+        self,
+        prompts: List[PromptValue],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return await self.agenerate(prompt_strings, stop=stop)
+        return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)

     def generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # If string is passed in directly no errors will be raised but outputs will
@@ -124,21 +150,31 @@ class BaseLLM(BaseLanguageModel, ABC):
                 f" argument of type {type(prompts)}."
             )
         disregard_cache = self.cache is not None and not self.cache
+        callback_manager = CallbackManager.configure(
+            callbacks, self.callbacks, self.verbose
+        )
+        new_arg_supported = inspect.signature(self._generate).parameters.get(
+            "run_manager"
+        )
         if langchain.llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+            run_manager = callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, prompts
             )
             try:
-                output = self._generate(prompts, stop=stop)
+                output = (
+                    self._generate(prompts, stop=stop, run_manager=run_manager)
+                    if new_arg_supported
+                    else self._generate(prompts, stop=stop)
+                )
             except (KeyboardInterrupt, Exception) as e:
-                self.callback_manager.on_llm_error(e, verbose=self.verbose)
+                run_manager.on_llm_error(e)
                 raise e
-            self.callback_manager.on_llm_end(output, verbose=self.verbose)
+            run_manager.on_llm_end(output)
             return output
         params = self.dict()
         params["stop"] = stop
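The new_arg_supported check above is what keeps pre-existing third-party subclasses working: the base class only passes run_manager when the subclass's signature declares it. The same inspection in isolation (old_style and new_style are illustrative stand-ins):

import inspect


def old_style(prompts, stop=None):
    return "old"


def new_style(prompts, stop=None, run_manager=None):
    return "new"


for fn in (old_style, new_style):
    # parameters.get(...) returns None (falsy) when the name is absent.
    new_arg_supported = inspect.signature(fn).parameters.get("run_manager")
    print(fn.__name__, bool(new_arg_supported))  # old_style False, new_style True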
@@ -149,15 +185,19 @@ class BaseLLM(BaseLanguageModel, ABC):
             missing_prompts,
         ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
+            run_manager = callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, missing_prompts
             )
             try:
-                new_results = self._generate(missing_prompts, stop=stop)
+                new_results = (
+                    self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+                    if new_arg_supported
+                    else self._generate(missing_prompts, stop=stop)
+                )
             except (KeyboardInterrupt, Exception) as e:
-                self.callback_manager.on_llm_error(e, verbose=self.verbose)
+                run_manager.on_llm_error(e)
                 raise e
-            self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
+            run_manager.on_llm_end(new_results)
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
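For context, the get_prompts/update_cache branch above only runs when a global cache is configured. A hedged usage sketch (llm stands in for any concrete BaseLLM; InMemoryCache is langchain's built-in cache):

import langchain
from langchain.cache import InMemoryCache

langchain.llm_cache = InMemoryCache()
# result_1 = llm.generate(["same prompt"])  # cache miss: _generate runs
# result_2 = llm.generate(["same prompt"])  # cache hit: _generate is skipped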
@@ -167,36 +207,38 @@ class BaseLLM(BaseLanguageModel, ABC):
         return LLMResult(generations=generations, llm_output=llm_output)

     async def agenerate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         disregard_cache = self.cache is not None and not self.cache
+        callback_manager = AsyncCallbackManager.configure(
+            callbacks, self.callbacks, self.verbose
+        )
+        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
+            "run_manager"
+        )
         if langchain.llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_llm_start(
-                    {"name": self.__class__.__name__}, prompts, verbose=self.verbose
-                )
-            else:
-                self.callback_manager.on_llm_start(
-                    {"name": self.__class__.__name__}, prompts, verbose=self.verbose
-                )
+            run_manager = await callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, prompts
+            )
             try:
-                output = await self._agenerate(prompts, stop=stop)
+                output = (
+                    await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+                    if new_arg_supported
+                    else await self._agenerate(prompts, stop=stop)
+                )
             except (KeyboardInterrupt, Exception) as e:
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_error(e, verbose=self.verbose)
-                else:
-                    self.callback_manager.on_llm_error(e, verbose=self.verbose)
+                await run_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_llm_end(output, verbose=self.verbose)
-            else:
-                self.callback_manager.on_llm_end(output, verbose=self.verbose)
+            await run_manager.on_llm_end(output, verbose=self.verbose)
             return output
         params = self.dict()
         params["stop"] = stop
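The async path mirrors the sync one, so the awaitable entry point takes the same callbacks argument. A hedged usage sketch (llm stands in for any BaseLLM subclass that implements _acall):

import asyncio


async def main(llm) -> str:
    result = await llm.agenerate(["Hello"], callbacks=None)
    return result.generations[0][0].text

# print(asyncio.run(main(llm)))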
@@ -207,32 +249,22 @@ class BaseLLM(BaseLanguageModel, ABC):
             missing_prompts,
         ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_llm_start(
-                    {"name": self.__class__.__name__},
-                    missing_prompts,
-                    verbose=self.verbose,
-                )
-            else:
-                self.callback_manager.on_llm_start(
-                    {"name": self.__class__.__name__},
-                    missing_prompts,
-                    verbose=self.verbose,
-                )
+            run_manager = await callback_manager.on_llm_start(
+                {"name": self.__class__.__name__},
+                missing_prompts,
+            )
             try:
-                new_results = await self._agenerate(missing_prompts, stop=stop)
-            except (KeyboardInterrupt, Exception) as e:
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_error(e, verbose=self.verbose)
-                else:
-                    self.callback_manager.on_llm_error(e, verbose=self.verbose)
-                raise e
-            if self.callback_manager.is_async:
-                await self.callback_manager.on_llm_end(
-                    new_results, verbose=self.verbose
-                )
-            else:
-                self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
+                new_results = (
+                    await self._agenerate(
+                        missing_prompts, stop=stop, run_manager=run_manager
+                    )
+                    if new_arg_supported
+                    else await self._agenerate(missing_prompts, stop=stop)
+                )
+            except (KeyboardInterrupt, Exception) as e:
+                await run_manager.on_llm_error(e)
+                raise e
+            await run_manager.on_llm_end(new_results)
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
@@ -241,9 +273,15 @@ class BaseLLM(BaseLanguageModel, ABC):
             generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output)

-    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def __call__(
+        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+    ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        return self.generate([prompt], stop=stop).generations[0][0].text
+        return (
+            self.generate([prompt], stop=stop, callbacks=callbacks)
+            .generations[0][0]
+            .text
+        )

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
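With callbacks threaded through __call__, per-call handlers no longer require mutating a shared global manager. A hedged sketch of a token-printing handler (PrintTokens is illustrative and assumes the BaseCallbackHandler interface on this branch):

from langchain.callbacks.base import BaseCallbackHandler


class PrintTokens(BaseCallbackHandler):
    """Print streamed tokens as they arrive."""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)

# text = llm("Tell me a joke", callbacks=[PrintTokens()])  # llm: any streaming LLM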
@@ -307,30 +345,56 @@ class LLM(BaseLLM):
     """

     @abstractmethod
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Run the LLM on the given prompt and input."""

-    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
         """Run the LLM on the given prompt and input."""
         raise NotImplementedError("Async generation not implemented for this LLM.")

     def _generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # TODO: add caching here.
         generations = []
+        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
         for prompt in prompts:
-            text = self._call(prompt, stop=stop)
+            text = (
+                self._call(prompt, stop=stop, run_manager=run_manager)
+                if new_arg_supported
+                else self._call(prompt, stop=stop)
+            )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)

     async def _agenerate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         generations = []
+        new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
         for prompt in prompts:
-            text = await self._acall(prompt, stop=stop)
+            text = (
+                await self._acall(prompt, stop=stop, run_manager=run_manager)
+                if new_arg_supported
+                else await self._acall(prompt, stop=stop)
+            )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
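Putting the new simplified LLM interface together: a custom subclass only needs _llm_type and a _call that optionally accepts run_manager. EchoLLM below is an illustrative sketch, not part of the diff:

from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Toy LLM that echoes the prompt back, emitting word-level tokens."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        for word in prompt.split():
            if run_manager:
                run_manager.on_llm_new_token(word)
        return prompt

# EchoLLM()("hello world") returns "hello world"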
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional

 from pydantic import Extra, Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -81,7 +82,12 @@ class CerebriumAI(LLM):
         """Return type of llm."""
         return "cerebriumai"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call to CerebriumAI endpoint."""
         try:
             from cerebrium import model_api_request
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional

 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -100,7 +101,12 @@ class Cohere(LLM):
         """Return type of llm."""
         return "cohere"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to Cohere's generate endpoint.

         Args:
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
 import requests
 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -60,7 +61,12 @@ class DeepInfra(LLM):
         """Return type of llm."""
         return "deepinfra"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to DeepInfra's inference API endpoint.

         Args:
@@ -1,6 +1,7 @@
 """Fake LLM wrapper for testing purposes."""
 from typing import Any, List, Mapping, Optional

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM


@@ -15,7 +16,12 @@ class FakeListLLM(LLM):
         """Return type of llm."""
         return "fake-list"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """First try to lookup in queries, else return 'foo' or 'bar'."""
         response = self.responses[self.i]
         self.i += 1
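Usage sketch for the fake LLM above: it serves its canned responses in order, which keeps chain unit tests deterministic and offline (assumes the FakeListLLM class defined in this test module):

llm = FakeListLLM(responses=["first", "second"])
assert llm("any prompt") == "first"
assert llm("another prompt") == "second"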
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
 import requests
 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -81,7 +82,12 @@ class ForefrontAI(LLM):
         """Return type of llm."""
         return "forefrontai"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to ForefrontAI's complete endpoint.

         Args:
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional

 from pydantic import Extra, Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.utils import get_from_dict_or_env

@@ -130,7 +131,12 @@ class GooseAI(LLM):
         """Return type of llm."""
         return "gooseai"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call the GooseAI API."""
         params = self._default_params
         if stop is not None:
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set

 from pydantic import Extra, Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens

@@ -159,7 +160,12 @@ class GPT4All(LLM):
         """Return the type of llm."""
         return "gpt4all"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         r"""Call out to GPT4All's generate method.

         Args:
@@ -175,14 +181,15 @@ class GPT4All(LLM):
                 prompt = "Once upon a time, "
                 response = model(prompt, n_predict=55)
         """
-        text_callback = partial(
-            self.callback_manager.on_llm_new_token, verbose=self.verbose
-        )
-        text = self.client.generate(
-            prompt,
-            new_text_callback=text_callback,
-            **self._default_params,
-        )
+        if run_manager:
+            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
+            text = self.client.generate(
+                prompt,
+                new_text_callback=text_callback,
+                **self._default_params,
+            )
+        else:
+            text = self.client.generate(prompt, **self._default_params)
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text
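The partial() call above adapts run_manager.on_llm_new_token to the single-argument new_text_callback that the GPT4All client invokes, with verbose pre-bound. The same trick in isolation (the callback here is an illustrative stand-in):

from functools import partial


def on_llm_new_token(token: str, verbose: bool = False) -> None:
    print(("[verbose] " if verbose else "") + token)


text_callback = partial(on_llm_new_token, verbose=True)
text_callback("hello")  # prints "[verbose] hello"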
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
 import requests
 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -88,7 +89,12 @@ class HuggingFaceEndpoint(LLM):
         """Return type of llm."""
         return "huggingface_endpoint"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.

         Args:
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional

 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -84,7 +85,12 @@ class HuggingFaceHub(LLM):
         """Return type of llm."""
         return "huggingface_hub"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.

         Args:
@@ -5,6 +5,7 @@ from typing import Any, List, Mapping, Optional

 from pydantic import Extra

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens

@@ -146,7 +147,12 @@ class HuggingFacePipeline(LLM):
     def _llm_type(self) -> str:
         return "huggingface_pipeline"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         response = self.pipeline(prompt)
         if self.pipeline.task == "text-generation":
             # Text generation return includes the starter text.
@@ -4,6 +4,7 @@ from typing import Any, Dict, Generator, List, Optional

 from pydantic import Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM

 logger = logging.getLogger(__name__)
@@ -197,7 +198,12 @@ class LlamaCpp(LLM):

         return params

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call the Llama model and return the output.

         Args:
@@ -219,7 +225,7 @@ class LlamaCpp(LLM):
             # method that yields as they are generated
             # and return the combined strings from the first choices's text:
             combined_text_output = ""
-            for token in self.stream(prompt=prompt, stop=stop):
+            for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
                 combined_text_output += token["choices"][0]["text"]
             return combined_text_output
         else:
@@ -228,7 +234,10 @@ class LlamaCpp(LLM):
             return result["choices"][0]["text"]

     def stream(
-        self, prompt: str, stop: Optional[List[str]] = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> Generator[Dict, None, None]:
         """Yields results objects as they are generated in real time.

@@ -268,7 +277,8 @@ class LlamaCpp(LLM):
         for chunk in result:
             token = chunk["choices"][0]["text"]
             log_probs = chunk["choices"][0].get("logprobs", None)
-            self.callback_manager.on_llm_new_token(
-                token=token, verbose=self.verbose, log_probs=log_probs
-            )
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=token, verbose=self.verbose, log_probs=log_probs
+                )
             yield chunk
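Each chunk yielded by stream() above follows llama-cpp-python's completion-chunk shape, so the final text is the concatenation of choices[0]["text"]. A standalone sketch with a stand-in generator (fake_stream is illustrative):

from typing import Dict, Generator


def fake_stream() -> Generator[Dict, None, None]:
    for text in ["Hello", ", ", "world"]:
        yield {"choices": [{"text": text, "logprobs": None}]}


combined_text_output = ""
for token in fake_stream():
    combined_text_output += token["choices"][0]["text"]
print(combined_text_output)  # prints "Hello, world"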
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional

 from pydantic import Extra, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM


@@ -42,7 +43,12 @@ class ManifestWrapper(LLM):
         """Return type of llm."""
         return "manifest"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call out to LLM through Manifest."""
         if stop is not None and len(stop) != 1:
             raise NotImplementedError(
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
 import requests
 from pydantic import Extra, Field, root_validator

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens

@@ -69,7 +70,12 @@ class Modal(LLM):
         """Return type of llm."""
         return "modal"

-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
         """Call to Modal endpoint."""
         params = self.model_kwargs or {}
         response = requests.post(
Some files were not shown because too many files have changed in this diff.