Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-21 21:56:38 +00:00)

Compare commits: vwp/drafts...dev2049/nu (12 commits)
795f823bb1, 6cd653deb4, e953d2cf93, 50668693d7, 6fec15b6fb, 7bcdc66b99, 4cdd19bd4e, 90cef7b53a, 675e27c136, fa4a4f2940, 55c7964e4e, 3cc2ce6ac9
@@ -28,13 +28,15 @@ To stream the model's predictions, add in a CallbackManager.
```python
from langchain.llms import GPT4All
-from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# There are many CallbackHandlers supported, such as
# from langchain.callbacks.streamlit import StreamlitCallbackHandler

callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True)
+model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_manager=callback_manager,
+                verbose=True)

# Generate text. Tokens are streamed through the callback manager.
model("Once upon a time, ")
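For comparison, here is a minimal sketch of the same streaming setup with a custom handler instead of `StreamingStdOutCallbackHandler`. The `TokenPrinter` class and the model path are illustrative, not part of the repository; the `GPT4All` constructor arguments simply mirror the snippet above.

```python
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.llms import GPT4All

class TokenPrinter(BaseCallbackHandler):
    """Illustrative handler: print each streamed token as it arrives."""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)

callback_manager = CallbackManager([TokenPrinter()])
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8,
                callback_manager=callback_manager, verbose=True)
model("Once upon a time, ")
```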
@@ -17,33 +17,7 @@
"source": [
"LangChain provides a callback system that allows you to hook into the various stages of your LLM application. This is useful for logging, [monitoring](https://python.langchain.com/en/latest/tracing.html), [streaming](https://python.langchain.com/en/latest/modules/models/llms/examples/streaming_llm.html), and other tasks.\n",
"\n",
-"You can subscribe to these events by using the `callback_manager` argument available throughout the API. A `CallbackManager` is an object that manages a list of `CallbackHandlers`. The `CallbackManager` will call the appropriate method on each handler when the event is triggered."
-]
-},
-{
-"cell_type": "markdown",
-"id": "fdb72e8d-a02a-474d-96bf-f5759432afc8",
-"metadata": {
-"tags": []
-},
-"source": [
-"```python\n",
-"class CallbackManager(BaseCallbackHandler):\n",
-"    \"\"\"Base callback manager that can be used to handle callbacks from LangChain.\"\"\"\n",
-"\n",
-"    def add_handler(self, callback: BaseCallbackHandler) -> None:\n",
-"        \"\"\"Add a handler to the callback manager.\"\"\"\n",
-"\n",
-"    def remove_handler(self, handler: BaseCallbackHandler) -> None:\n",
-"        \"\"\"Remove a handler from the callback manager.\"\"\"\n",
-"\n",
-"    def set_handler(self, handler: BaseCallbackHandler) -> None:\n",
-"        \"\"\"Set handler as the only handler on the callback manager.\"\"\"\n",
-"        self.set_handlers([handler])\n",
-"\n",
-"    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:\n",
-"        \"\"\"Set handlers as the only handlers on the callback manager.\"\"\"\n",
-"```"
+"You can subscribe to these events by using the `callbacks` argument available throughout the API. This argument is a list of handler objects, which are expected to implement one or more of the methods described in the API docs."
]
},
{
@@ -62,70 +36,57 @@
|
||||
},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"class BaseCallbackHandler(ABC):\n",
|
||||
"class BaseCallbackHandler:\n",
|
||||
" \"\"\"Base callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_llm_start(\n",
|
||||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when LLM starts running.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run on new LLM token. Only available when streaming is enabled.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run when LLM ends running.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_llm_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when LLM errors.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_chain_start(\n",
|
||||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when chain starts running.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run when chain ends running.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_chain_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when chain errors.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_tool_start(\n",
|
||||
" self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when tool starts running.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_tool_end(self, output: str, **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run when tool ends running.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_tool_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when tool errors.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_text(self, text: str, **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run on arbitrary text.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run on agent action.\"\"\"\n",
|
||||
"\n",
|
||||
" @abstractmethod\n",
|
||||
" def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:\n",
|
||||
" \"\"\"Run on agent end.\"\"\"\n",
|
||||
"```"
|
||||
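In the new version of `BaseCallbackHandler` shown above (no longer an `ABC` with abstract methods), a subclass only needs to override the events it cares about. A minimal sketch, assuming the `langchain.callbacks.base` import path used elsewhere in this notebook; the `SelectiveHandler` name is illustrative:

```python
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler

class SelectiveHandler(BaseCallbackHandler):
    """Illustrative handler that only reacts to LLM start and new-token events."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print(f"LLM starting with {len(prompts)} prompt(s)")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(f"token: {token!r}")
```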
@@ -136,9 +97,11 @@
"id": "d3bf3304-43fb-47ad-ae50-0637a17018a2",
"metadata": {},
"source": [
-"## Creating and Using a Custom `CallbackHandler`\n",
+"## Using an existing handler\n",
"\n",
-"By default, a shared CallbackManager with the StdOutCallbackHandler will be used by models, chains, agents, and tools. However, you can pass in your own CallbackManager with a custom CallbackHandler:"
+"LangChain provides a few built-in handlers that you can use to get started. These are available in the `langchain/callbacks` module. The most basic handler is the `StdOutCallbackHandler`, which simply logs all events to `stdout`. In the future we will add more default handlers to the library.\n",
+"\n",
+"**Note** when the `verbose` flag on the object is set to true, the `StdOutCallbackHandler` will be invoked even without being explicitly passed in."
]
},
{
@@ -155,16 +118,16 @@
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"AgentAction(tool='Search', tool_input=\"US Open men's final 2019 winner\", log=' I need to find out who won the US Open men\\'s final in 2019 and then calculate his age raised to the 0.334 power.\\nAction: Search\\nAction Input: \"US Open men\\'s final 2019 winner\"')\n",
|
||||
"Rafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\n",
|
||||
"AgentAction(tool='Search', tool_input='Rafael Nadal age', log=' I need to find out the age of the winner\\nAction: Search\\nAction Input: \"Rafael Nadal age\"')\n",
|
||||
"36 years\n",
|
||||
"AgentAction(tool='Calculator', tool_input='36^0.334', log=' I now need to calculate his age raised to the 0.334 power\\nAction: Calculator\\nAction Input: 36^0.334')\n",
|
||||
"Answer: 3.3098250249682484\n",
|
||||
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n",
|
||||
"\n",
|
||||
" I now know the final answer\n",
|
||||
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@@ -172,7 +135,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\""
|
||||
"'\\n\\n3'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
@@ -181,108 +144,89 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import Any, Dict, List, Optional, Union\n",
|
||||
"\n",
|
||||
"from langchain.agents import initialize_agent, load_tools\n",
|
||||
"from langchain.agents import AgentType\n",
|
||||
"from langchain.callbacks.base import CallbackManager, BaseCallbackHandler\n",
|
||||
"from langchain.callbacks import StdOutCallbackHandler\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.schema import AgentAction, AgentFinish, LLMResult\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"class MyCustomCallbackHandler(BaseCallbackHandler):\n",
|
||||
" \"\"\"Custom CallbackHandler.\"\"\"\n",
|
||||
"handler = StdOutCallbackHandler()\n",
|
||||
"llm = OpenAI()\n",
|
||||
"prompt = PromptTemplate.from_template(\"1 + {number} = \")\n",
|
||||
"\n",
|
||||
" def on_llm_start(\n",
|
||||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Print out the prompts.\"\"\"\n",
|
||||
" pass\n",
|
||||
"# First, let's explicitly set the StdOutCallbackHandler in `callbacks`\n",
|
||||
"chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])\n",
|
||||
"chain.run(number=2)\n",
|
||||
"\n",
|
||||
" def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Do nothing.\"\"\"\n",
|
||||
" pass\n",
|
||||
"# Then, let's use the `verbose` flag to achieve the same result\n",
|
||||
"chain = LLMChain(llm=llm, prompt=prompt, verbose=True)\n",
|
||||
"chain.run(number=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "389c8448-5283-49e3-8c04-dbe1522e202c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Creating a custom handler\n",
|
||||
"\n",
|
||||
" def on_llm_new_token(self, token: str, **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Do nothing.\"\"\"\n",
|
||||
" pass\n",
|
||||
"You can create a custom handler to set on the object as well. In the example below, we'll implement streaming with a custom handler."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "1b2e6588-0681-4cab-937a-7cc4790cea9a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"My custom handler, token: \n",
|
||||
"My custom handler, token: Why\n",
|
||||
"My custom handler, token: did\n",
|
||||
"My custom handler, token: the\n",
|
||||
"My custom handler, token: tomato\n",
|
||||
"My custom handler, token: turn\n",
|
||||
"My custom handler, token: red\n",
|
||||
"My custom handler, token: ?\n",
|
||||
"My custom handler, token: Because\n",
|
||||
"My custom handler, token: it\n",
|
||||
"My custom handler, token: saw\n",
|
||||
"My custom handler, token: the\n",
|
||||
"My custom handler, token: salad\n",
|
||||
"My custom handler, token: dressing\n",
|
||||
"My custom handler, token: !\n",
|
||||
"My custom handler, token: \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Why did the tomato turn red? Because it saw the salad dressing!', additional_kwargs={})"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.callbacks.base import BaseCallbackHandler\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.schema import HumanMessage\n",
|
||||
"\n",
|
||||
" def on_llm_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Do nothing.\"\"\"\n",
|
||||
" pass\n",
|
||||
"class MyCustomHandler(BaseCallbackHandler):\n",
|
||||
" def on_llm_new_token(self, token: str, **kwargs) -> None:\n",
|
||||
" print(f\"My custom handler, token: {token}\")\n",
|
||||
"\n",
|
||||
" def on_chain_start(\n",
|
||||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Print out that we are entering a chain.\"\"\"\n",
|
||||
" class_name = serialized[\"name\"]\n",
|
||||
" print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n",
|
||||
"# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n",
|
||||
"# Additionally, we pass in a list with our custom handler\n",
|
||||
"chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomHandler()])\n",
|
||||
"\n",
|
||||
" def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Print out that we finished a chain.\"\"\"\n",
|
||||
" print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n",
|
||||
"\n",
|
||||
" def on_chain_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Do nothing.\"\"\"\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def on_tool_start(\n",
|
||||
" self,\n",
|
||||
" serialized: Dict[str, Any],\n",
|
||||
" input_str: str,\n",
|
||||
" **kwargs: Any,\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Do nothing.\"\"\"\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def on_agent_action(\n",
|
||||
" self, action: AgentAction, color: Optional[str] = None, **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run on agent action.\"\"\"\n",
|
||||
" print(action)\n",
|
||||
"\n",
|
||||
" def on_tool_end(\n",
|
||||
" self,\n",
|
||||
" output: str,\n",
|
||||
" color: Optional[str] = None,\n",
|
||||
" observation_prefix: Optional[str] = None,\n",
|
||||
" llm_prefix: Optional[str] = None,\n",
|
||||
" **kwargs: Any,\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"If not the final action, print out observation.\"\"\"\n",
|
||||
" print(output)\n",
|
||||
"\n",
|
||||
" def on_tool_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Do nothing.\"\"\"\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def on_text(\n",
|
||||
" self,\n",
|
||||
" text: str,\n",
|
||||
" color: Optional[str] = None,\n",
|
||||
" end: str = \"\",\n",
|
||||
" **kwargs: Optional[str],\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Run when agent ends.\"\"\"\n",
|
||||
" print(text)\n",
|
||||
"\n",
|
||||
" def on_agent_finish(\n",
|
||||
" self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Run on agent end.\"\"\"\n",
|
||||
" print(finish.log)\n",
|
||||
"manager = CallbackManager([MyCustomCallbackHandler()])\n",
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)\n",
|
||||
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, callback_manager=manager)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager\n",
|
||||
")\n",
|
||||
"agent.run(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")"
|
||||
"chat([HumanMessage(content=\"Tell me a joke\")])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -292,9 +236,11 @@
"tags": []
},
"source": [
-"## Async Support\n",
+"## Async Callbacks\n",
"\n",
-"If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` and `AsyncCallbackManager` to avoid blocking the runloop."
+"If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` to avoid blocking the runloop.\n",
+"\n",
+"**Advanced** if you use a sync `CallbackHandler` while using an async method to run your llm/chain/tool/agent, it will still work. However, under the hood, it will be called with [`run_in_executor`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) which can cause issues if your `CallbackHandler` is not thread-safe."
]
},
{
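Following the thread-safety caveat above, any state shared by a sync handler that may be driven from async code can be guarded with a lock. A minimal sketch; the counter class is illustrative and not part of LangChain:

```python
import threading

from langchain.callbacks.base import BaseCallbackHandler

class ThreadSafeTokenCounter(BaseCallbackHandler):
    """Illustrative sync handler that stays correct when invoked via
    run_in_executor from async code: the token count is guarded by a lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.tokens = 0

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        with self._lock:
            self.tokens += 1
```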
@@ -310,58 +256,589 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"zzzz....\n",
|
||||
"Hi! I just woke up. Your llm is starting\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: \n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: Why\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: don\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: 't\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: scientists\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: trust\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: atoms\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: ?\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: Because\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: they\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: make\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: up\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: everything\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: !\n",
|
||||
"Sync handler being called in a `thread_pool_executor`: token: \n",
|
||||
"zzzz....\n",
|
||||
"Hi! I just woke up. Your llm is ending\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LLMResult(generations=[[ChatGeneration(text=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", generation_info=None, message=AIMessage(content=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", additional_kwargs={}))]], llm_output={'token_usage': {}, 'model_name': 'gpt-3.5-turbo'})"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"from typing import Any, Dict, List\n",
|
||||
"from langchain.schema import LLMResult\n",
|
||||
"from langchain.callbacks.base import AsyncCallbackHandler\n",
|
||||
"\n",
|
||||
"class MyCustomSyncHandler(BaseCallbackHandler):\n",
|
||||
" def on_llm_new_token(self, token: str, **kwargs) -> None:\n",
|
||||
" print(f\"Sync handler being called in a `thread_pool_executor`: token: {token}\")\n",
|
||||
"\n",
|
||||
"class MyCustomAsyncHandler(AsyncCallbackHandler):\n",
|
||||
" \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
|
||||
"\n",
|
||||
" async def on_llm_start(\n",
|
||||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Run when chain starts running.\"\"\"\n",
|
||||
" print(\"zzzz....\")\n",
|
||||
" await asyncio.sleep(0.3)\n",
|
||||
" class_name = serialized[\"name\"]\n",
|
||||
" print(\"Hi! I just woke up. Your llm is starting\")\n",
|
||||
"\n",
|
||||
" async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Run when chain ends running.\"\"\"\n",
|
||||
" print(\"zzzz....\")\n",
|
||||
" await asyncio.sleep(0.3)\n",
|
||||
" print(\"Hi! I just woke up. Your llm is ending\")\n",
|
||||
"\n",
|
||||
"# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n",
|
||||
"# Additionally, we pass in a list with our custom handler\n",
|
||||
"chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()])\n",
|
||||
"\n",
|
||||
"await chat.agenerate([[HumanMessage(content=\"Tell me a joke\")]])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d26dbb34-fcc3-401c-a115-39c7620d2d65",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using multiple handlers, passing in handlers\n",
|
||||
"\n",
|
||||
"In the previous examples, we passed in callback handlers upon creation of an object by using `callbacks=`. In this case, the callbacks will be scoped to that particular object. \n",
|
||||
"\n",
|
||||
"However, in many cases, it is advantageous to pass in handlers instead when running the object. When we pass through `CallbackHandlers` using the `callbacks` keyword arg when executing an run, those callbacks will be issued by all nested objects involved in the execution. For example, when a handler is passed through to an `Agent`, it will be used for all callbacks related to the agent and all the objects involved in the agent's execution, in this case, the `Tools`, `LLMChain`, and `LLM`.\n",
|
||||
"\n",
|
||||
"This prevents us from having to manually attach the handlers to each individual nested object."
|
||||
]
|
||||
},
|
||||
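In short, the difference is where the handlers are attached. A minimal sketch contrasting the two scopes; `handler1`, `handler2`, `llm`, and `agent` are the objects defined in the code cell that follows, so this fragment is not standalone:

```python
# Constructor-scoped: only this LLM's events reach handler2
llm = OpenAI(temperature=0, streaming=True, callbacks=[handler2])

# Request-scoped: handler1 receives events from the agent executor and every
# nested object (LLMChain, LLM, tools) involved in this single run
agent.run("What is 2 raised to the 0.235 power?", callbacks=[handler1])
```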
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "8eec8756-1828-45cb-9699-38ac8543a150",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"on_chain_start AgentExecutor\n",
|
||||
"on_chain_start LLMChain\n",
|
||||
"on_llm_start OpenAI\n",
|
||||
"on_llm_start (I'm the second handler!!) OpenAI\n",
|
||||
"on_new_token I\n",
|
||||
"on_new_token need\n",
|
||||
"on_new_token to\n",
|
||||
"on_new_token use\n",
|
||||
"on_new_token a\n",
|
||||
"on_new_token calculator\n",
|
||||
"on_new_token to\n",
|
||||
"on_new_token solve\n",
|
||||
"on_new_token this\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token \n",
|
||||
"Action\n",
|
||||
"on_new_token :\n",
|
||||
"on_new_token Calculator\n",
|
||||
"on_new_token \n",
|
||||
"Action\n",
|
||||
"on_new_token Input\n",
|
||||
"on_new_token :\n",
|
||||
"on_new_token 2\n",
|
||||
"on_new_token ^\n",
|
||||
"on_new_token 0\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token 235\n",
|
||||
"on_new_token \n",
|
||||
"on_agent_action AgentAction(tool='Calculator', tool_input='2^0.235', log=' I need to use a calculator to solve this.\\nAction: Calculator\\nAction Input: 2^0.235')\n",
|
||||
"on_tool_start Calculator\n",
|
||||
"on_chain_start LLMMathChain\n",
|
||||
"on_chain_start LLMChain\n",
|
||||
"on_llm_start OpenAI\n",
|
||||
"on_llm_start (I'm the second handler!!) OpenAI\n",
|
||||
"on_new_token \n",
|
||||
"\n",
|
||||
"on_new_token ```text\n",
|
||||
"on_new_token \n",
|
||||
"\n",
|
||||
"on_new_token 2\n",
|
||||
"on_new_token **\n",
|
||||
"on_new_token 0\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token 235\n",
|
||||
"on_new_token \n",
|
||||
"\n",
|
||||
"on_new_token ```\n",
|
||||
"\n",
|
||||
"on_new_token ...\n",
|
||||
"on_new_token num\n",
|
||||
"on_new_token expr\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token evaluate\n",
|
||||
"on_new_token (\"\n",
|
||||
"on_new_token 2\n",
|
||||
"on_new_token **\n",
|
||||
"on_new_token 0\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token 235\n",
|
||||
"on_new_token \")\n",
|
||||
"on_new_token ...\n",
|
||||
"on_new_token \n",
|
||||
"\n",
|
||||
"on_new_token \n",
|
||||
"on_chain_start LLMChain\n",
|
||||
"on_llm_start OpenAI\n",
|
||||
"on_llm_start (I'm the second handler!!) OpenAI\n",
|
||||
"on_new_token I\n",
|
||||
"on_new_token now\n",
|
||||
"on_new_token know\n",
|
||||
"on_new_token the\n",
|
||||
"on_new_token final\n",
|
||||
"on_new_token answer\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token \n",
|
||||
"Final\n",
|
||||
"on_new_token Answer\n",
|
||||
"on_new_token :\n",
|
||||
"on_new_token 1\n",
|
||||
"on_new_token .\n",
|
||||
"on_new_token 17\n",
|
||||
"on_new_token 690\n",
|
||||
"on_new_token 67\n",
|
||||
"on_new_token 372\n",
|
||||
"on_new_token 187\n",
|
||||
"on_new_token 674\n",
|
||||
"on_new_token \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'1.1769067372187674'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import Dict, Union, Any, List\n",
|
||||
"\n",
|
||||
"from langchain.callbacks.base import BaseCallbackHandler\n",
|
||||
"from langchain.schema import AgentAction\n",
|
||||
"from langchain.agents import AgentType, initialize_agent, load_tools\n",
|
||||
"from langchain.callbacks import tracing_enabled\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"# First, define custom callback handler implementations\n",
|
||||
"class MyCustomHandlerOne(BaseCallbackHandler):\n",
|
||||
" def on_llm_start(\n",
|
||||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" print(f\"on_llm_start {serialized['name']}\")\n",
|
||||
"\n",
|
||||
" def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n",
|
||||
" print(f\"on_new_token {token}\")\n",
|
||||
"\n",
|
||||
" def on_llm_error(\n",
|
||||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" \"\"\"Run when LLM errors.\"\"\"\n",
|
||||
"\n",
|
||||
" def on_chain_start(\n",
|
||||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" print(f\"on_chain_start {serialized['name']}\")\n",
|
||||
"\n",
|
||||
" def on_tool_start(\n",
|
||||
" self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" print(f\"on_tool_start {serialized['name']}\")\n",
|
||||
"\n",
|
||||
" def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n",
|
||||
" print(f\"on_agent_action {action}\")\n",
|
||||
"\n",
|
||||
"class MyCustomHandlerTwo(BaseCallbackHandler):\n",
|
||||
" def on_llm_start(\n",
|
||||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||||
" ) -> Any:\n",
|
||||
" print(f\"on_llm_start (I'm the second handler!!) {serialized['name']}\")\n",
|
||||
"\n",
|
||||
"# Instantiate the handlers\n",
|
||||
"handler1 = MyCustomHandlerOne()\n",
|
||||
"handler2 = MyCustomHandlerTwo()\n",
|
||||
"\n",
|
||||
"# Setup the agent. Only the `llm` will issue callbacks for handler2\n",
|
||||
"llm = OpenAI(temperature=0, streaming=True, callbacks=[handler2])\n",
|
||||
"tools = load_tools([\"llm-math\"], llm=llm)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Callbacks for handler1 will be issued by every object involved in the \n",
|
||||
"# Agent execution (llm, llmchain, tool, agent executor)\n",
|
||||
"agent.run(\"What is 2 raised to the 0.235 power?\", callbacks=[handler1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "32b29135-f852-4492-88ed-547275c72c53",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tracing and Token Counting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fbb606d6-2863-46c5-8347-9f0bdb3805bb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Tracing and token counting are two capabilities we provide which are built on our callbacks mechanism."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f62cd10c-494c-47d6-aa98-6e926cb9c456",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tracing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d5a74b3f-3769-4a4f-99c7-b6a3b20a94e2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"There are two recommended ways to trace your LangChains:\n",
|
||||
"\n",
|
||||
"1. Setting the `LANGCHAIN_TRACING` environment variable to `\"true\"`. \n",
|
||||
"2. Using a context manager `with tracing_enabled()` to trace a particular block of code.\n",
|
||||
"\n",
|
||||
"**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager."
|
||||
]
|
||||
},
|
||||
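As a quick reference, the two approaches demonstrated in the cells below can be condensed as follows; `agent` and `questions` are defined in the preceding cell, and the session name "my_test_session" is the one used later in this notebook:

```python
import os

from langchain.callbacks import tracing_enabled

# 1. Environment variable: every run made while it is set is traced
os.environ["LANGCHAIN_TRACING"] = "true"
agent.run(questions[0])
del os.environ["LANGCHAIN_TRACING"]

# 2. Context manager: only runs inside the block are traced
with tracing_enabled("my_test_session") as session:
    assert session
    agent.run(questions[1])
```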
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "f164dfd5-d987-4b6a-a7c8-019c651ce47f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain.agents import AgentType, initialize_agent, load_tools\n",
|
||||
"from langchain.callbacks import tracing_enabled\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"# To run the code, make sure to set OPENAI_API_KEY and SERPAPI_API_KEY\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n",
|
||||
"agent = initialize_agent(\n",
|
||||
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"questions = [\n",
|
||||
" \"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\",\n",
|
||||
" \"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\",\n",
|
||||
" \"Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?\",\n",
|
||||
" \"Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?\",\n",
|
||||
" \"Who is Beyonce's husband? What is his age raised to the 0.19 power?\",\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6be7777e-ec1d-438f-ae33-3a93c45f808e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"zzzz....\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Rafael Nadal age\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 36^0.334\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Harry Styles age\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 29^0.23\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
|
||||
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"from aiohttp import ClientSession\n",
|
||||
"os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n",
|
||||
"\n",
|
||||
"from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager\n",
|
||||
"\n",
|
||||
"class MyCustomAsyncCallbackHandler(AsyncCallbackHandler):\n",
|
||||
" \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
|
||||
"\n",
|
||||
" async def on_chain_start(\n",
|
||||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Run when chain starts running.\"\"\"\n",
|
||||
" print(\"zzzz....\")\n",
|
||||
" await asyncio.sleep(0.5)\n",
|
||||
" class_name = serialized[\"name\"]\n",
|
||||
" print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n",
|
||||
"\n",
|
||||
" async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Run when chain ends running.\"\"\"\n",
|
||||
" print(\"zzzz....\")\n",
|
||||
" await asyncio.sleep(0.5)\n",
|
||||
" print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n",
|
||||
"\n",
|
||||
"manager = AsyncCallbackManager([MyCustomAsyncCallbackHandler()])\n",
|
||||
"\n",
|
||||
"# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
|
||||
"# but you must manually close the client session at the end of your program/event loop\n",
|
||||
"aiosession = ClientSession()\n",
|
||||
"llm = OpenAI(temperature=0, callback_manager=manager)\n",
|
||||
"async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n",
|
||||
"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
|
||||
"await async_agent.arun(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")\n",
|
||||
"await aiosession.close()"
|
||||
"# Both of the agent runs will be traced because the environment variable is set\n",
|
||||
"agent.run(questions[0])\n",
|
||||
"with tracing_enabled() as session:\n",
|
||||
" assert session\n",
|
||||
" agent.run(questions[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "86be6304-e433-4048-880c-a92a73244407",
|
||||
"execution_count": 10,
|
||||
"id": "a6fd6026-dc1e-4d48-893d-3592539c7828",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Rafael Nadal age\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 36^0.334\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Harry Styles age\"\u001b[0m\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 29^0.23\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
|
||||
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Now, we unset the environment variable and use a context manager.\n",
|
||||
"\n",
|
||||
"if \"LANGCHAIN_TRACING\" in os.environ:\n",
|
||||
" del os.environ[\"LANGCHAIN_TRACING\"]\n",
|
||||
"\n",
|
||||
"# here, we are writing traces to \"my_test_session\"\n",
|
||||
"with tracing_enabled(\"my_test_session\") as session:\n",
|
||||
" assert session\n",
|
||||
" agent.run(questions[0]) # this should be traced\n",
|
||||
"\n",
|
||||
"agent.run(questions[1]) # this should not be traced"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9383a351-4983-44e9-abd7-ef942e1c65c4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\u001b[33;1m\u001b[1;3mLewis Hamilton has won 103 Grands Prix during his career. He won 21 races with McLaren and has won 82 with Mercedes. Lewis Hamilton holds the record for the ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[33;1m\u001b[1;3m36 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Harry Styles age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Lewis Hamilton's age\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Lewis Hamilton Age\"\u001b[0m\u001b[33;1m\u001b[1;3m29 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 29^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\u001b[33;1m\u001b[1;3m38 years\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I now need to calculate 38 raised to the 0.23 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 38^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.3086081644669734\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# The context manager is concurrency safe:\n",
|
||||
"if \"LANGCHAIN_TRACING\" in os.environ:\n",
|
||||
" del os.environ[\"LANGCHAIN_TRACING\"]\n",
|
||||
"\n",
|
||||
"# start a background task\n",
|
||||
"task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n",
|
||||
"with tracing_enabled() as session:\n",
|
||||
" assert session\n",
|
||||
" tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n",
|
||||
" await asyncio.gather(*tasks)\n",
|
||||
"\n",
|
||||
"await task"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "254fef1b-6b6e-4352-9cf4-363fba895ac7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Token Counting\n",
|
||||
"LangChain offers a context manager that allows you to count tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "5c3e0b89-2c5e-4036-bdf2-fb6b750e360c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from langchain.callbacks import get_openai_callback\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"with get_openai_callback() as cb:\n",
|
||||
" llm(\"What is the square root of 4?\")\n",
|
||||
"\n",
|
||||
"total_tokens = cb.total_tokens\n",
|
||||
"assert total_tokens > 0\n",
|
||||
"\n",
|
||||
"with get_openai_callback() as cb:\n",
|
||||
" llm(\"What is the square root of 4?\")\n",
|
||||
" llm(\"What is the square root of 4?\")\n",
|
||||
"\n",
|
||||
"assert cb.total_tokens == total_tokens * 2\n",
|
||||
"\n",
|
||||
"# You can kick off concurrent runs from within the context manager\n",
|
||||
"with get_openai_callback() as cb:\n",
|
||||
" await asyncio.gather(\n",
|
||||
" *[llm.agenerate([\"What is the square root of 4?\"]) for _ in range(3)]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert cb.total_tokens == total_tokens * 3\n",
|
||||
"\n",
|
||||
"# The context manager is concurrency safe\n",
|
||||
"task = asyncio.create_task(llm.agenerate([\"What is the square root of 4?\"]))\n",
|
||||
"with get_openai_callback() as cb:\n",
|
||||
" await llm.agenerate([\"What is the square root of 4?\"])\n",
|
||||
"\n",
|
||||
"await task\n",
|
||||
"assert cb.total_tokens == total_tokens"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -5,11 +5,6 @@ from typing import Optional
|
||||
|
||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.cache import BaseCache
|
||||
from langchain.callbacks import (
|
||||
set_default_callback_manager,
|
||||
set_handler,
|
||||
set_tracing_callback_manager,
|
||||
)
|
||||
from langchain.chains import (
|
||||
ConversationChain,
|
||||
LLMBashChain,
|
||||
@@ -65,7 +60,6 @@ del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
verbose: bool = False
|
||||
llm_cache: Optional[BaseCache] = None
|
||||
set_default_callback_manager()
|
||||
|
||||
# For backwards compatibility
|
||||
SerpAPIChain = SerpAPIWrapper
|
||||
@@ -115,7 +109,5 @@ __all__ = [
|
||||
"VectorDBQAWithSourcesChain",
|
||||
"QAWithSourcesChain",
|
||||
"PALChain",
|
||||
"set_handler",
|
||||
"set_tracing_callback_manager",
|
||||
"LlamaCpp",
|
||||
]
|
||||
|
||||
@@ -13,7 +13,13 @@ import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.agents.tools import InvalidTool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
@@ -23,7 +29,6 @@ from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
)
|
||||
@@ -46,13 +51,17 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
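For downstream code, the practical effect of this signature change is that a custom agent's `plan` should accept a `callbacks` argument and forward it to whatever it calls internally, so nested chain and LLM events are reported against the same run. A minimal sketch of just the overridden method, mirroring `LLMSingleActionAgent.plan` later in this diff (minus the `stop` argument); the surrounding class, `self.llm_chain`, and `self.output_parser` are assumed to exist:

```python
def plan(
    self,
    intermediate_steps: List[Tuple[AgentAction, str]],
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
    # Forward run-scoped callbacks so nested chain/LLM events are reported too
    output = self.llm_chain.run(
        intermediate_steps=intermediate_steps, callbacks=callbacks, **kwargs
    )
    return self.output_parser.parse(output)
```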
@@ -61,13 +70,17 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
async def aplan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -170,13 +183,17 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -185,13 +202,17 @@ class BaseMultiActionAgent(BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
async def aplan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
@@ -285,38 +306,52 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
||||
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
output = self.llm_chain.run(
|
||||
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
|
||||
intermediate_steps=intermediate_steps,
|
||||
stop=self.stop,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
return self.output_parser.parse(output)
|
||||
|
||||
async def aplan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
output = await self.llm_chain.arun(
|
||||
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
|
||||
intermediate_steps=intermediate_steps,
|
||||
stop=self.stop,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
return self.output_parser.parse(output)
|
||||
|
||||
@@ -368,37 +403,45 @@ class Agent(BaseSingleActionAgent):
|
||||
return thoughts
|
||||
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = self.llm_chain.predict(**full_inputs)
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
async def aplan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = await self.llm_chain.apredict(**full_inputs)
|
||||
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
def get_full_inputs(
|
||||
@@ -632,24 +675,27 @@ class AgentExecutor(Chain):
|
||||
|
||||
return True
|
||||
|
||||
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
|
||||
self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
def _return(
|
||||
self,
|
||||
output: AgentFinish,
|
||||
intermediate_steps: list,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if run_manager:
|
||||
run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
|
||||
async def _areturn(
|
||||
self, output: AgentFinish, intermediate_steps: list
|
||||
self,
|
||||
output: AgentFinish,
|
||||
intermediate_steps: list,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_agent_finish(
|
||||
if run_manager:
|
||||
await run_manager.on_agent_finish(
|
||||
output, color="green", verbose=self.verbose
|
||||
)
|
||||
final_output = output.return_values
|
||||
@@ -663,13 +709,18 @@ class AgentExecutor(Chain):
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
|
||||
"""Take a single step in the thought-action-observation loop.
|
||||
|
||||
Override this to take control of how the agent makes and acts on choices.
|
||||
"""
|
||||
# Call the LLM to see what to do.
|
||||
output = self.agent.plan(intermediate_steps, **inputs)
|
||||
output = self.agent.plan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
@@ -680,9 +731,8 @@ class AgentExecutor(Chain):
actions = output
result = []
for agent_action in actions:
self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]

@@ -696,6 +746,7 @@ class AgentExecutor(Chain):
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:

@@ -704,6 +755,7 @@ class AgentExecutor(Chain):
agent_action.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
result.append((agent_action, observation))

@@ -715,13 +767,18 @@ class AgentExecutor(Chain):
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.

Override this to take control of how the agent makes and acts on choices.
"""
# Call the LLM to see what to do.
output = await self.agent.aplan(intermediate_steps, **inputs)
output = await self.agent.aplan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output

@@ -734,12 +791,8 @@ class AgentExecutor(Chain):
async def _aperform_agent_action(
agent_action: AgentAction,
) -> Tuple[AgentAction, str]:
if self.callback_manager.is_async:
await self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
else:
self.callback_manager.on_agent_action(
if run_manager:
await run_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool

@@ -755,6 +808,7 @@ class AgentExecutor(Chain):
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:

@@ -763,6 +817,7 @@ class AgentExecutor(Chain):
agent_action.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return agent_action, observation

@@ -774,7 +829,11 @@ class AgentExecutor(Chain):

return list(result)

def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}

@@ -790,10 +849,16 @@ class AgentExecutor(Chain):
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations, time_elapsed):
next_step_output = self._take_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return self._return(next_step_output, intermediate_steps)
return self._return(
next_step_output, intermediate_steps, run_manager=run_manager
)

intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:

@@ -809,7 +874,11 @@ class AgentExecutor(Chain):
)
return self._return(output, intermediate_steps)

async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}

@@ -827,7 +896,11 @@ class AgentExecutor(Chain):
try:
while self._should_continue(iterations, time_elapsed):
next_step_output = await self._atake_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)

@@ -845,7 +918,9 @@ class AgentExecutor(Chain):
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
return await self._areturn(
output, intermediate_steps, run_manager=run_manager
)
except TimeoutError:
# stop early when interrupted by the async timeout
output = self.agent.return_stopped_response(
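Taken together, the `_call` / `_acall` changes mean callbacks are scoped to a single run instead of being attached to the chain at construction time. Assuming the public entry points on this branch accept the `Callbacks` value introduced in `langchain/callbacks/manager.py`, usage would look roughly like this (the executor itself is assumed to be already built):

```python
from langchain.callbacks.stdout import StdOutCallbackHandler

handler = StdOutCallbackHandler()

# Handlers passed here apply to this run only and are propagated to the agent's
# LLM calls and tool invocations via run_manager.get_child().
result = agent_executor.run("What is 2 + 2?", callbacks=[handler])
```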
@@ -26,12 +26,12 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.requests.tool import BaseRequestsTool

@@ -5,6 +5,7 @@ from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate

@@ -13,7 +14,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseLanguageModel
from langchain.schema import AgentAction
from langchain.tools import BaseTool

@@ -9,10 +9,10 @@ from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

@@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
SUFFIX,
TEMPLATE_TOOL_RESPONSE,
)
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.base import BasePromptTemplate

@@ -24,7 +25,6 @@ from langchain.prompts.chat import (
from langchain.schema import (
AgentAction,
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,

@@ -4,8 +4,8 @@ from typing import Any, Optional, Sequence
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

@@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
return Tool(
name="Calculator",
description="Useful for when you need to answer questions about math.",
func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run,
coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun,
func=LLMMathChain(llm=llm).run,
coroutine=LLMMathChain(llm=llm).arun,
)

@@ -10,10 +10,10 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

@@ -4,6 +4,11 @@ from typing import Any, Awaitable, Callable, Optional, Type, Union

from pydantic import BaseModel, validate_arguments

from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
Callbacks,
)
from langchain.tools.base import BaseTool

@@ -26,14 +31,42 @@ class Tool(BaseTool):
valid_keys = signature(self.func).parameters
return {k: schema[k] for k in valid_keys}

def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
"""Use the tool."""
return self.func(*args, **kwargs)
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)

async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
"""Use the tool asynchronously."""
new_argument_supported = signature(self.coroutine).parameters.get("callbacks")
if self.coroutine:
return await self.coroutine(*args, **kwargs)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
raise NotImplementedError("Tool does not support async")

# TODO: this is for backwards compatibility, remove in future
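With the `Tool._run` / `_arun` changes above, a wrapped function only receives the child callbacks when its signature declares a `callbacks` parameter; otherwise it is invoked exactly as before. A rough sketch (the search function and its behaviour are illustrative only):

```python
from langchain.agents.tools import Tool
from langchain.callbacks.manager import Callbacks


def search(query: str, callbacks: Callbacks = None) -> str:
    # Because `callbacks` is in the signature, Tool._run forwards
    # run_manager.get_child() here; nested chains can then reuse it.
    return f"results for {query!r}"


search_tool = Tool(
    name="Search",
    func=search,
    description="Toy search tool used to illustrate callback propagation.",
)
```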
langchain/base_language.py (new file, 55 lines)

@@ -0,0 +1,55 @@
"""Base class for all language models."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel

from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string


class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""

@abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""

def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)

# calculate the number of tokens in the tokenized text
return len(tokenized_text)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the message."""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
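The new `langchain/base_language.py` module pulls `BaseLanguageModel` out of `langchain.schema` (the import swaps in the hunks above reflect that move) and adds a `callbacks` parameter to both prompt-generation methods. A minimal, illustrative subclass might look like the following; it is a toy echo model, not anything from the library:

```python
from typing import List, Optional

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import Generation, LLMResult, PromptValue


class EchoLanguageModel(BaseLanguageModel):
    """Toy model that echoes its prompts; for illustration only."""

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        # Echo each prompt back as a single generation.
        return LLMResult(
            generations=[[Generation(text=p.to_string())] for p in prompts]
        )

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        return self.generate_prompt(prompts, stop=stop, callbacks=callbacks)
```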
@@ -1,80 +1,27 @@
"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator, Optional
from typing import Generator

from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.manager import (
CallbackManager,
get_openai_callback,
tracing_enabled,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.tracers import LangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler


def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()


def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)


def set_default_callback_manager() -> None:
"""Set default callback manager."""
default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
if default_handler == "stdout":
set_handler(StdOutCallbackHandler())
elif default_handler == "langchain":
session = os.environ.get("LANGCHAIN_SESSION")
set_tracing_callback_manager(session)
else:
raise ValueError(
f"LANGCHAIN_HANDLER should be one of `stdout` "
f"or `langchain`, got {default_handler}"
)


def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
"""Set tracing callback manager."""
handler = SharedLangChainTracer()
callback = get_callback_manager()
callback.set_handlers([handler, StdOutCallbackHandler()])
if session_name is None:
handler.load_default_session()
else:
try:
handler.load_session(session_name)
except Exception:
raise ValueError(f"session {session_name} not found")


@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get OpenAI callback handler in a context manager."""
handler = OpenAICallbackHandler()
manager = get_callback_manager()
manager.add_handler(handler)
yield handler
manager.remove_handler(handler)


__all__ = [
"CallbackManager",
"AsyncCallbackManager",
"OpenAICallbackHandler",
"SharedCallbackManager",
"StdOutCallbackHandler",
"AimCallbackHandler",
"WandbCallbackHandler",

@@ -82,8 +29,5 @@ __all__ = [
"CometCallbackHandler",
"AsyncIteratorCallbackHandler",
"get_openai_callback",
"set_tracing_callback_manager",
"set_default_callback_manager",
"set_handler",
"get_callback_manager",
"tracing_enabled",
]
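After this refactor, `langchain.callbacks` no longer exposes the shared-manager helpers (`get_callback_manager`, `set_handler`, `set_tracing_callback_manager`, and friends); token accounting and tracing are instead scoped with the context managers re-exported from `langchain.callbacks.manager`. A short usage sketch, where `llm` and `agent_executor` are assumed to already exist:

```python
from langchain.callbacks import get_openai_callback, tracing_enabled

# Cost/token accounting is now per-context rather than process-global.
with get_openai_callback() as cb:
    llm("Tell me a joke")
    print(cb.total_tokens)

# Tracing for a block of code, against the named tracer session.
with tracing_enabled(session_name="default") as session:
    agent_executor.run("What is 2 + 2?")
```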
@@ -1,19 +1,173 @@
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
"""Base callback handler that can be used to handle callbacks in langchain."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class BaseCallbackHandler(ABC):
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
class LLMManagerMixin:
|
||||
"""Mixin for LLM callbacks."""
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return False
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class ToolManagerMixin:
|
||||
"""Mixin for tool callbacks."""
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
|
||||
class CallbackManagerMixin:
|
||||
"""Mixin for callback manager."""
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
|
||||
class RunManagerMixin:
|
||||
"""Mixin for run manager."""
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
RunManagerMixin,
|
||||
):
|
||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
@@ -30,480 +184,197 @@ class BaseCallbackHandler(ABC):
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> Any:
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> Any:
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
async def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent action."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: List[BaseCallbackHandler] = (
|
||||
inheritable_handlers or []
|
||||
)
|
||||
self.parent_run_id: Optional[str] = parent_run_id
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Whether the callback manager is async."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
if inherit:
|
||||
self.inheritable_handlers.append(handler)
|
||||
|
||||
@abstractmethod
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
self.inheritable_handlers.remove(handler)
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
def set_handlers(
|
||||
self, handlers: List[BaseCallbackHandler], inherit: bool = True
|
||||
) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = []
|
||||
self.inheritable_handlers = []
|
||||
for handler in handlers:
|
||||
self.add_handler(handler, inherit=inherit)
|
||||
|
||||
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Set handler as the only handler on the callback manager."""
|
||||
self.set_handlers([handler])
|
||||
self.set_handlers([handler], inherit=inherit)
|
||||
|
||||
@abstractmethod
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
|
||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||
|
||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_new_token(
|
||||
self, token: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_new_token(token, **kwargs)
|
||||
|
||||
def on_llm_end(
|
||||
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_end(response)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(
|
||||
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_chain_end(outputs)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_chain:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_chain_error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_start(serialized, input_str, **kwargs)
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_agent_action(action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_tool_error(error)
|
||||
|
||||
def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
for handler in self.handlers:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_agent:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_agent_finish(finish, **kwargs)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
async def _handle_event_for_handler(
|
||||
handler: BaseCallbackHandler,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
verbose: bool,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
if verbose or handler.always_verbose:
|
||||
event = getattr(handler, event_name)
|
||||
if asyncio.iscoroutinefunction(event):
|
||||
await event(*args, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(event, *args, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Async callback manager that can be used to handle callbacks from LangChain."""
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the handler is async."""
|
||||
return True
|
||||
|
||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
async def _handle_event(
|
||||
self,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
verbose: bool,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Generic event handler for AsyncCallbackManager."""
|
||||
await asyncio.gather(
|
||||
*(
|
||||
_handle_event_for_handler(
|
||||
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
|
||||
)
|
||||
for handler in self.handlers
|
||||
)
|
||||
def __copy__(self) -> "BaseCallbackManager":
|
||||
return self.__class__(
|
||||
self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id
|
||||
)
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
await self._handle_event(
|
||||
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
|
||||
def __deepcopy__(self, memo: dict) -> "BaseCallbackManager":
|
||||
return self.__class__(
|
||||
[copy.deepcopy(handler, memo) for handler in self.handlers],
|
||||
[copy.deepcopy(handler, memo) for handler in self.inheritable_handlers],
|
||||
self.parent_run_id,
|
||||
)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self, token: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
await self._handle_event(
|
||||
"on_llm_new_token", "ignore_llm", verbose, token, **kwargs
|
||||
)
|
||||
|
||||
async def on_llm_end(
|
||||
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
await self._handle_event(
|
||||
"on_llm_end", "ignore_llm", verbose, response, **kwargs
|
||||
)
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
await self._handle_event(
|
||||
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
|
||||
)
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
await self._handle_event(
|
||||
"on_chain_end", "ignore_chain", verbose, outputs, **kwargs
|
||||
)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
await self._handle_event(
|
||||
"on_chain_error", "ignore_chain", verbose, error, **kwargs
|
||||
)
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
await self._handle_event(
|
||||
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self, output: str, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
await self._handle_event(
|
||||
"on_tool_end", "ignore_agent", verbose, output, **kwargs
|
||||
)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
await self._handle_event(
|
||||
"on_tool_error", "ignore_agent", verbose, error, **kwargs
|
||||
)
|
||||
|
||||
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||
"""Run when text is printed."""
|
||||
await self._handle_event("on_text", None, verbose, text, **kwargs)
|
||||
|
||||
async def on_agent_action(
|
||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent action."""
|
||||
await self._handle_event(
|
||||
"on_agent_action", "ignore_agent", verbose, action, **kwargs
|
||||
)
|
||||
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when agent finishes."""
|
||||
await self._handle_event(
|
||||
"on_agent_finish", "ignore_agent", verbose, finish, **kwargs
|
||||
)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
||||
langchain/callbacks/manager.py (new file, 770 lines)
@@ -0,0 +1,770 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
ChainManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RunManagerMixin,
|
||||
ToolManagerMixin,
|
||||
)
|
||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.base import TracerSession
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
|
||||
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
|
||||
"openai_callback", default=None
|
||||
)
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
|
||||
"tracing_callback", default=None
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
||||
"""Get OpenAI callback handler in a context manager."""
|
||||
cb = OpenAICallbackHandler()
|
||||
openai_callback_var.set(cb)
|
||||
yield cb
|
||||
openai_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_enabled(
|
||||
session_name: str = "default",
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
"""Get OpenAI callback handler in a context manager."""
|
||||
cb = LangChainTracer()
|
||||
session = cb.load_session(session_name)
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
tracing_callback_var.set(None)
|
||||
|
||||
|
||||
def _handle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
for handler in handlers:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(
|
||||
handler, ignore_condition_name
|
||||
):
|
||||
getattr(handler, event_name)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# TODO: switch this to use logging
|
||||
print(f"Error in {event_name} callback: {e}")
|
||||
|
||||
|
||||
async def _ahandle_event_for_handler(
|
||||
handler: BaseCallbackHandler,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
event = getattr(handler, event_name)
|
||||
if asyncio.iscoroutinefunction(event):
|
||||
await event(*args, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(event, *args, **kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: switch this to use logging
|
||||
print(f"Error in {event_name} callback: {e}")
|
||||
|
||||
|
||||
async def _ahandle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Generic event handler for AsyncCallbackManager."""
|
||||
await asyncio.gather(
|
||||
*(
|
||||
_ahandle_event_for_handler(
|
||||
handler, event_name, ignore_condition_name, *args, **kwargs
|
||||
)
|
||||
for handler in handlers
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BaseRunManager(RunManagerMixin):
|
||||
"""Base class for run manager (a bound callback manager)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler],
|
||||
parent_run_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize run manager."""
|
||||
self.run_id = run_id
|
||||
self.handlers = handlers
|
||||
self.inheritable_handlers = inheritable_handlers
|
||||
self.parent_run_id = parent_run_id
|
||||
|
||||
|
||||
class RunManager(BaseRunManager):
|
||||
"""Sync Run Manager."""
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when text is received."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_text",
|
||||
None,
|
||||
text,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncRunManager(BaseRunManager):
|
||||
"""Async Run Manager."""
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when text is received."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_text",
|
||||
None,
|
||||
text,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
"""Callback manager for LLM run."""
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_new_token",
|
||||
"ignore_llm",
|
||||
token=token,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_end",
|
||||
"ignore_llm",
|
||||
response,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_llm_error",
|
||||
"ignore_llm",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Async callback manager for LLM run."""
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: Optional[str] = None,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_new_token",
|
||||
"ignore_llm",
|
||||
token,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_end",
|
||||
"ignore_llm",
|
||||
response,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_error",
|
||||
"ignore_llm",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
||||
"""Callback manager for chain run."""
|
||||
|
||||
def get_child(self) -> Callbacks:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chain_end",
|
||||
"ignore_chain",
|
||||
outputs,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_chain_error",
|
||||
"ignore_chain",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run when agent action is received."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_agent_action",
|
||||
"ignore_agent",
|
||||
action,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run when agent finish is received."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_agent_finish",
|
||||
"ignore_agent",
|
||||
finish,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
||||
"""Async callback manager for chain run."""
|
||||
|
||||
def get_child(self) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_end",
|
||||
"ignore_chain",
|
||||
outputs,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_error",
|
||||
"ignore_chain",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run when agent action is received."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_agent_action",
|
||||
"ignore_agent",
|
||||
action,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run when agent finish is received."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_agent_finish",
|
||||
"ignore_agent",
|
||||
finish,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
||||
"""Callback manager for tool run."""
|
||||
|
||||
def get_child(self) -> CallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool ends running."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
"ignore_agent",
|
||||
output,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_tool_error",
|
||||
"ignore_agent",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||
"""Async callback manager for tool run."""
|
||||
|
||||
def get_child(self) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
return manager
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
"ignore_agent",
|
||||
output,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_error",
|
||||
"ignore_agent",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackManager):
    """Callback manager that can be used to handle callbacks from langchain."""

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallbackManagerForLLMRun:
        """Run when LLM starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        _handle_event(
            self.handlers,
            "on_llm_start",
            "ignore_llm",
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return CallbackManagerForLLMRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallbackManagerForChainRun:
        """Run when chain starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        _handle_event(
            self.handlers,
            "on_chain_start",
            "ignore_chain",
            serialized,
            inputs,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return CallbackManagerForChainRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        run_id: Optional[str] = None,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallbackManagerForToolRun:
        """Run when tool starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        _handle_event(
            self.handlers,
            "on_tool_start",
            "ignore_agent",
            serialized,
            input_str,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return CallbackManagerForToolRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    @classmethod
    def configure(
        cls,
        inheritable_callbacks: Optional[
            Union[BaseCallbackManager, List[BaseCallbackHandler]]
        ] = None,
        local_callbacks: Optional[
            Union[BaseCallbackManager, List[BaseCallbackHandler]]
        ] = None,
        verbose: bool = False,
    ) -> Optional[BaseCallbackManager]:
        """Configure the callback manager."""
        return _configure(cls, inheritable_callbacks, local_callbacks, verbose)


class AsyncCallbackManager(BaseCallbackManager):
    """Async callback manager that can be used to handle callbacks from LangChain."""

    @property
    def is_async(self) -> bool:
        """Return whether the handler is async."""
        return True

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForLLMRun:
        """Run when LLM starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        await _ahandle_event(
            self.handlers,
            "on_llm_start",
            "ignore_llm",
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return AsyncCallbackManagerForLLMRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForChainRun:
        """Run when chain starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        await _ahandle_event(
            self.handlers,
            "on_chain_start",
            "ignore_chain",
            serialized,
            inputs,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return AsyncCallbackManagerForChainRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        run_id: Optional[str] = None,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncCallbackManagerForToolRun:
        """Run when tool starts running."""
        if run_id is None:
            run_id = str(uuid.uuid4())

        await _ahandle_event(
            self.handlers,
            "on_tool_start",
            "ignore_agent",
            serialized,
            input_str,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

        return AsyncCallbackManagerForToolRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    @classmethod
    def configure(
        cls,
        inheritable_callbacks: Optional[
            Union[BaseCallbackManager, List[BaseCallbackHandler]]
        ] = None,
        local_callbacks: Optional[
            Union[BaseCallbackManager, List[BaseCallbackHandler]]
        ] = None,
        verbose: bool = False,
    ) -> Optional[BaseCallbackManager]:
        """Configure the callback manager."""
        return _configure(cls, inheritable_callbacks, local_callbacks, verbose)


T = TypeVar("T", CallbackManager, AsyncCallbackManager)


def _configure(
    callback_manager_cls: Type[T],
    inheritable_callbacks: Callbacks = None,
    local_callbacks: Callbacks = None,
    verbose: bool = False,
) -> T:
    """Configure the callback manager."""
    callback_manager = callback_manager_cls([])
    if inheritable_callbacks or local_callbacks:
        if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
            inheritable_callbacks_: List[BaseCallbackHandler] = (
                inheritable_callbacks or []
            )
            callback_manager = callback_manager_cls(
                handlers=inheritable_callbacks_,
                inheritable_handlers=inheritable_callbacks_,
            )
        else:
            callback_manager = callback_manager_cls(
                handlers=inheritable_callbacks.handlers,
                inheritable_handlers=inheritable_callbacks.inheritable_handlers,
                parent_run_id=inheritable_callbacks.parent_run_id,
            )
        callback_manager = copy.deepcopy(callback_manager)
        local_handlers_ = (
            local_callbacks
            if isinstance(local_callbacks, list)
            else (local_callbacks.handlers if local_callbacks else [])
        )
        for handler in local_handlers_:
            callback_manager.add_handler(copy.deepcopy(handler), False)

    tracer = tracing_callback_var.get()
    open_ai = openai_callback_var.get()
    tracing_enabled_ = (
        os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None
    )
    if verbose or tracing_enabled_ or open_ai is not None:
        if verbose and not any(
            isinstance(handler, StdOutCallbackHandler)
            for handler in callback_manager.handlers
        ):
            callback_manager.add_handler(StdOutCallbackHandler(), False)

        if tracing_enabled_ and not any(
            isinstance(handler, LangChainTracer)
            for handler in callback_manager.handlers
        ):
            if tracer:
                callback_manager.add_handler(copy.deepcopy(tracer), True)
            else:
                handler = LangChainTracer()
                handler.load_default_session()
                callback_manager.add_handler(handler, True)
        if open_ai is not None and not any(
            isinstance(handler, OpenAICallbackHandler)
            for handler in callback_manager.handlers
        ):
            callback_manager.add_handler(open_ai, True)

    return callback_manager


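For orientation only (not part of the diff): a minimal sketch of what `_configure` ends up doing when tracing is switched on through the environment. It assumes a reachable tracing endpoint; otherwise `LangChainTracer` falls back to an empty session with a warning.

```python
import os

from langchain.callbacks.manager import CallbackManager

os.environ["LANGCHAIN_TRACING"] = "true"  # picked up inside _configure

manager = CallbackManager.configure(verbose=True)
# verbose=True adds a StdOutCallbackHandler (non-inheritable);
# LANGCHAIN_TRACING adds a LangChainTracer as an inheritable handler,
# so nested runs created via get_child() are traced as well.
```
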
class NullCallbackManagerForChainRun(CallbackManagerForChainRun):
    def __init__(self) -> None:
        super().__init__("", [], [])

    def on_text(
        self,
        text: str,
        **kwargs: Any,
    ) -> Any:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        pass

    def get_child(self) -> Callbacks:
        return None
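Again for orientation (illustrative, not in the diff): how a chain is expected to drive these per-run managers, mirroring what `Chain.__call__` does further down. `MyHandler` and the inputs are made-up names.

```python
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager


class MyHandler(BaseCallbackHandler):
    """Illustrative handler: print whenever a chain starts."""

    def on_chain_start(self, serialized, inputs, **kwargs):
        print("chain started:", serialized.get("name"), inputs)


manager = CallbackManager.configure(inheritable_callbacks=[MyHandler()])
run_manager = manager.on_chain_start({"name": "MyChain"}, {"question": "hi"})
child = run_manager.get_child()  # inheritable handlers propagate to nested runs
run_manager.on_chain_end({"answer": "hello"})
```
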
@@ -1,127 +0,0 @@
|
||||
"""A shared CallbackManager."""
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class Singleton:
|
||||
"""A thread-safe singleton class that can be inherited from."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> Any:
|
||||
"""Create a new shared instance of the class."""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
# Another thread could have created the instance
|
||||
# before we acquired the lock. So check that the
|
||||
# instance is still nonexistent.
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
"""A thread-safe singleton CallbackManager."""
|
||||
|
||||
_callback_manager: CallbackManager = CallbackManager(handlers=[])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_end(response, **kwargs)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_new_token(token, **kwargs)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_error(error, **kwargs)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_end(outputs, **kwargs)
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_chain_error(error, **kwargs)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_action(action, **kwargs)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_end(output, **kwargs)
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_error(error, **kwargs)
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_finish(finish, **kwargs)
|
||||
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a callback to the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.add_handler(callback)
|
||||
|
||||
def remove_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Remove a callback from the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.remove_handler(callback)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
with self._lock:
|
||||
self._callback_manager.handlers = handlers
|
||||
@@ -1,12 +1,5 @@
"""Tracers that record execution of LangChain runs."""

from langchain.callbacks.tracers.base import SharedTracer, Tracer
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer


class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
    """Shared tracer that records LangChain execution to LangChain endpoint."""


class LangChainTracer(Tracer, BaseLangChainTracer):
    """Tracer that records LangChain execution to LangChain endpoint."""
__all__ = ["LangChainTracer"]

@@ -1,14 +1,12 @@
|
||||
"""Base interfaces for tracing runs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.shared import Singleton
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
@@ -16,7 +14,7 @@ from langchain.callbacks.tracers.schemas import (
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class TracerException(Exception):
|
||||
@@ -26,13 +24,25 @@ class TracerException(Exception):
|
||||
class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Base interface for tracers."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
||||
self.session: Optional[TracerSession] = None
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
if isinstance(child_run, LLMRun):
|
||||
parent_run.child_llm_runs.append(child_run)
|
||||
elif isinstance(child_run, ChainRun):
|
||||
parent_run.child_chain_runs.append(child_run)
|
||||
elif isinstance(child_run, ToolRun):
|
||||
parent_run.child_tool_runs.append(child_run)
|
||||
else:
|
||||
raise TracerException(f"Invalid run type: {type(child_run)}")
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
@@ -42,15 +52,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
|
||||
@abstractmethod
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
|
||||
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
|
||||
"""NOT thread safe, do not call this method from multiple threads."""
|
||||
session_create = TracerSessionCreate(name=name, extra=kwargs)
|
||||
session = self._persist_session(session_create)
|
||||
self._session = session
|
||||
self.session = session
|
||||
return session
|
||||
|
||||
@abstractmethod
|
||||
@@ -61,283 +67,232 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
|
||||
@_execution_order.setter
|
||||
@abstractmethod
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
|
||||
@_session.setter
|
||||
@abstractmethod
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
|
||||
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Start a trace for a run."""
|
||||
self._execution_order += 1
|
||||
|
||||
if self._stack:
|
||||
if not (
|
||||
isinstance(self._stack[-1], ChainRun)
|
||||
or isinstance(self._stack[-1], ToolRun)
|
||||
):
|
||||
if run.parent_uuid:
|
||||
parent_run = self.run_map[run.parent_uuid]
|
||||
if parent_run:
|
||||
if isinstance(parent_run, LLMRun):
|
||||
raise TracerException(
|
||||
"Cannot add child run to an LLM run. "
|
||||
"LLM runs are not allowed to have children."
|
||||
)
|
||||
self._add_child_run(parent_run, run)
|
||||
else:
|
||||
raise TracerException(
|
||||
f"Nested {run.__class__.__name__} can only be"
|
||||
f" logged inside a ChainRun or ToolRun"
|
||||
f"Parent run with UUID {run.parent_uuid} not found."
|
||||
)
|
||||
self._add_child_run(self._stack[-1], run)
|
||||
self._stack.append(run)
|
||||
|
||||
def _end_trace(self) -> None:
|
||||
self.run_map[run.uuid] = run
|
||||
|
||||
def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""End a trace for a run."""
|
||||
run = self._stack.pop()
|
||||
if not self._stack:
|
||||
self._execution_order = 1
|
||||
if not run.parent_uuid:
|
||||
self._persist_run(run)
|
||||
else:
|
||||
parent_run = self.run_map.get(run.parent_uuid)
|
||||
if parent_run is None:
|
||||
raise TracerException(
|
||||
f"Parent run with UUID {run.parent_uuid} not found."
|
||||
)
|
||||
if isinstance(parent_run, LLMRun):
|
||||
raise TracerException("LLM Runs are not allowed to have children. ")
|
||||
if run.child_execution_order > parent_run.child_execution_order:
|
||||
parent_run.child_execution_order = run.child_execution_order
|
||||
self.run_map.pop(run.uuid)
|
||||
|
||||
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
if parent_run_id is None:
|
||||
return 1
|
||||
|
||||
parent_run = self.run_map.get(parent_run_id)
|
||||
if parent_run is None:
|
||||
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
|
||||
|
||||
if isinstance(parent_run, LLMRun):
|
||||
raise TracerException("LLM Runs are not allowed to have children. ")
|
||||
|
||||
return parent_run.child_execution_order + 1
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
if self.session is None:
|
||||
self.session = self.load_default_session()
|
||||
|
||||
if run_id is None:
|
||||
run_id = str(uuid4())
|
||||
|
||||
execution_order = self._get_execution_order(parent_run_id)
|
||||
llm_run = LLMRun(
|
||||
uuid=run_id,
|
||||
parent_uuid=parent_run_id,
|
||||
serialized=serialized,
|
||||
prompts=prompts,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
session_id=self.session.id,
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Handle a new token for an LLM run."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None:
|
||||
"""End a trace for an LLM run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], LLMRun):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_end callback.")
|
||||
|
||||
llm_run = self.run_map.get(run_id)
|
||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].response = response
|
||||
|
||||
self._end_trace()
|
||||
llm_run.response = response
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
self._end_trace(llm_run)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Handle an error for an LLM run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], LLMRun):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_error callback.")
|
||||
|
||||
llm_run = self.run_map.get(run_id)
|
||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
self._stack[-1].error = repr(error)
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
|
||||
self._end_trace()
|
||||
llm_run.error = repr(error)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
self._end_trace(llm_run)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for a chain run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
if self.session is None:
|
||||
self.session = self.load_default_session()
|
||||
|
||||
execution_order = self._get_execution_order(parent_run_id)
|
||||
chain_run = ChainRun(
|
||||
uuid=run_id,
|
||||
parent_uuid=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs=inputs,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
session_id=self.session.id,
|
||||
)
|
||||
self._start_trace(chain_run)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(
|
||||
self, outputs: Dict[str, Any], *, run_id: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""End a trace for a chain run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ChainRun):
|
||||
chain_run = self.run_map.get(run_id)
|
||||
if chain_run is None or not isinstance(chain_run, ChainRun):
|
||||
raise TracerException("No ChainRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].outputs = outputs
|
||||
|
||||
self._end_trace()
|
||||
chain_run.outputs = outputs
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
self._end_trace(chain_run)
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Handle an error for a chain run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ChainRun):
|
||||
chain_run = self.run_map.get(run_id)
|
||||
if chain_run is None or not isinstance(chain_run, ChainRun):
|
||||
raise TracerException("No ChainRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].error = repr(error)
|
||||
|
||||
self._end_trace()
|
||||
chain_run.error = repr(error)
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
self._end_trace(chain_run)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: str,
|
||||
parent_run_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for a tool run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
"Initialize a session with `new_session()` before starting a trace."
|
||||
)
|
||||
if self.session is None:
|
||||
self.session = self.load_default_session()
|
||||
|
||||
execution_order = self._get_execution_order(parent_run_id)
|
||||
tool_run = ToolRun(
|
||||
uuid=run_id,
|
||||
parent_uuid=parent_run_id,
|
||||
serialized=serialized,
|
||||
# TODO: this is duplicate info as above, not needed.
|
||||
action=str(serialized),
|
||||
tool_input=input_str,
|
||||
extra=kwargs,
|
||||
start_time=datetime.utcnow(),
|
||||
execution_order=self._execution_order,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
session_id=self._session.id,
|
||||
id=self._generate_id(),
|
||||
session_id=self.session.id,
|
||||
)
|
||||
self._start_trace(tool_run)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: str, *, run_id: str, **kwargs: Any) -> None:
|
||||
"""End a trace for a tool run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ToolRun):
|
||||
tool_run = self.run_map.get(run_id)
|
||||
if tool_run is None or not isinstance(tool_run, ToolRun):
|
||||
raise TracerException("No ToolRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].output = output
|
||||
|
||||
self._end_trace()
|
||||
tool_run.output = output
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
self._end_trace(tool_run)
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Handle an error for a tool run."""
|
||||
if not self._stack or not isinstance(self._stack[-1], ToolRun):
|
||||
tool_run = self.run_map.get(run_id)
|
||||
if tool_run is None or not isinstance(tool_run, ToolRun):
|
||||
raise TracerException("No ToolRun found to be traced")
|
||||
|
||||
self._stack[-1].end_time = datetime.utcnow()
|
||||
self._stack[-1].error = repr(error)
|
||||
tool_run.error = repr(error)
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
self._end_trace(tool_run)
|
||||
|
||||
self._end_trace()
|
||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||
"""Deepcopy the tracer."""
|
||||
return self
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Handle a text message."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Handle an agent finish message."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
|
||||
class Tracer(BaseTracer, ABC):
|
||||
"""A non-thread safe implementation of the BaseTracer interface."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a tracer."""
|
||||
self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
self._tracer_execution_order = 1
|
||||
self._tracer_session: Optional[TracerSession] = None
|
||||
|
||||
@property
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
return self._tracer_stack
|
||||
|
||||
@property
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
return self._tracer_execution_order
|
||||
|
||||
@_execution_order.setter
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
self._tracer_execution_order = value
|
||||
|
||||
@property
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
return self._tracer_session
|
||||
|
||||
@_session.setter
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
if self._stack:
|
||||
raise TracerException(
|
||||
"Cannot set a session while a trace is being recorded"
|
||||
)
|
||||
self._tracer_session = value
|
||||
|
||||
|
||||
@dataclass
|
||||
class TracerStack(threading.local):
|
||||
"""A stack of runs used for logging."""
|
||||
|
||||
stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
|
||||
execution_order: int = 1
|
||||
|
||||
|
||||
class SharedTracer(Singleton, BaseTracer, ABC):
|
||||
"""A thread-safe Singleton implementation of BaseTracer."""
|
||||
|
||||
_tracer_stack = TracerStack()
|
||||
_tracer_session = None
|
||||
|
||||
@property
|
||||
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
|
||||
"""Get the tracer stack."""
|
||||
return self._tracer_stack.stack
|
||||
|
||||
@property
|
||||
def _execution_order(self) -> int:
|
||||
"""Get the execution order for a run."""
|
||||
return self._tracer_stack.execution_order
|
||||
|
||||
@_execution_order.setter
|
||||
def _execution_order(self, value: int) -> None:
|
||||
"""Set the execution order for a run."""
|
||||
self._tracer_stack.execution_order = value
|
||||
|
||||
@property
|
||||
def _session(self) -> Optional[TracerSession]:
|
||||
"""Get the tracing session."""
|
||||
return self._tracer_session
|
||||
|
||||
@_session.setter
|
||||
def _session(self, value: TracerSession) -> None:
|
||||
"""Set the tracing session."""
|
||||
with self._lock:
|
||||
# TODO: currently, we are only checking current thread's stack.
|
||||
# Need to make sure that we are not in the middle of a trace
|
||||
# in any thread.
|
||||
if self._stack:
|
||||
raise TracerException(
|
||||
"Cannot set a session while a trace is being recorded"
|
||||
)
|
||||
self._tracer_session = value
|
||||
def __copy__(self) -> BaseTracer:
|
||||
"""Copy the tracer."""
|
||||
return self
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import requests
|
||||
@@ -18,14 +17,17 @@ from langchain.callbacks.tracers.schemas import (
|
||||
)
|
||||
|
||||
|
||||
class BaseLangChainTracer(BaseTracer, ABC):
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
always_verbose: bool = True
|
||||
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
self.session = self.load_session(session_name)
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
@@ -59,54 +61,29 @@ class BaseLangChainTracer(BaseTracer, ABC):
|
||||
session = TracerSession(id=1, **session_create.dict())
|
||||
return session
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
|
||||
"""Load a session from the tracer."""
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{self._endpoint}/sessions?name={session_name}",
|
||||
headers=self._headers,
|
||||
)
|
||||
url = f"{self._endpoint}/sessions"
|
||||
if session_name:
|
||||
url += f"?name={session_name}"
|
||||
r = requests.get(url, headers=self._headers)
|
||||
|
||||
tracer_session = TracerSession(**r.json()[0])
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
except Exception as e:
|
||||
session_type = "default" if not session_name else session_name
|
||||
logging.warning(
|
||||
f"Failed to load session {session_name}, using empty session: {e}"
|
||||
f"Failed to load {session_type} session, using empty session: {e}"
|
||||
)
|
||||
tracer_session = TracerSession(id=1)
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a session with the given name from the tracer."""
|
||||
return self._load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{self._endpoint}/sessions",
|
||||
headers=self._headers,
|
||||
)
|
||||
# Use the first session result
|
||||
tracer_session = TracerSession(**r.json()[0])
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to default session, using empty session: {e}")
|
||||
tracer_session = TracerSession(id=1)
|
||||
self._session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
if isinstance(child_run, LLMRun):
|
||||
parent_run.child_llm_runs.append(child_run)
|
||||
elif isinstance(child_run, ChainRun):
|
||||
parent_run.child_chain_runs.append(child_run)
|
||||
else:
|
||||
parent_run.child_tool_runs.append(child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return None
|
||||
return self._load_session("default")
|
||||
|
||||
@@ -32,11 +32,13 @@ class TracerSession(TracerSessionBase):
class BaseRun(BaseModel):
    """Base class for Run."""

    id: Optional[Union[int, str]] = None
    uuid: str
    parent_uuid: Optional[str] = None
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    extra: Optional[Dict[str, Any]] = None
    execution_order: int
    child_execution_order: int
    serialized: Dict[str, Any]
    session_id: int
    error: Optional[str] = None
@@ -57,7 +59,6 @@ class ChainRun(BaseRun):
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)
    child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)


class ToolRun(BaseRun):
@@ -69,7 +70,6 @@ class ToolRun(BaseRun):
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)
    child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)


ChainRun.update_forward_refs()

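A hedged sketch (values invented) of what the reworked run schema above implies when the tracer builds a run record:

```python
from langchain.callbacks.tracers.schemas import LLMRun

run = LLMRun(
    uuid="0b9d8c1e-example",   # runs are now keyed by uuid
    parent_uuid=None,
    serialized={"name": "OpenAI"},
    prompts=["Hello"],
    execution_order=1,
    child_execution_order=1,   # new field: highest execution_order reached in this run's subtree
    session_id=1,
)
```
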
@@ -5,12 +5,16 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
NullCallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.requests import TextRequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class APIChain(Chain):
|
||||
@@ -61,16 +65,18 @@ class APIChain(Chain):
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: CallbackManagerForChainRun = NullCallbackManagerForChainRun(),
|
||||
) -> Dict[str, str]:
|
||||
question = inputs[self.question_key]
|
||||
api_url = self.api_request_chain.predict(
|
||||
question=question, api_docs=self.api_docs
|
||||
)
|
||||
self.callback_manager.on_text(
|
||||
api_url, color="green", end="\n", verbose=self.verbose
|
||||
question=question, api_docs=self.api_docs, callbacks=run_manager.get_child()
|
||||
)
|
||||
run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
|
||||
api_response = self.requests_wrapper.get(api_url)
|
||||
self.callback_manager.on_text(
|
||||
run_manager.on_text(
|
||||
api_response, color="yellow", end="\n", verbose=self.verbose
|
||||
)
|
||||
answer = self.api_answer_chain.predict(
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
NullCallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema import BaseMemory
|
||||
|
||||
|
||||
@@ -21,9 +30,8 @@ class Chain(BaseModel, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
callback_manager: BaseCallbackManager = Field(
|
||||
default_factory=get_callback_manager, exclude=True
|
||||
)
|
||||
callbacks: Callbacks = None
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
@@ -37,15 +45,16 @@ class Chain(BaseModel, ABC):
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
@@ -82,15 +91,26 @@ class Chain(BaseModel, ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: CallbackManagerForChainRun = NullCallbackManagerForChainRun(),
|
||||
) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
raise NotImplementedError("Async call not supported for this chain type.")
|
||||
|
||||
def __call__(
|
||||
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
@@ -104,21 +124,31 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
self.callback_manager.on_chain_start(
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
try:
|
||||
outputs = self._call(inputs)
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
run_manager.on_chain_end(outputs)
|
||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||
|
||||
async def acall(
|
||||
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
@@ -132,30 +162,24 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
{"name": self.__class__.__name__},
|
||||
inputs,
|
||||
)
|
||||
try:
|
||||
outputs = await self._acall(inputs)
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
await run_manager.on_chain_end(outputs)
|
||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||
|
||||
def prep_outputs(
|
||||
@@ -195,11 +219,13 @@ class Chain(BaseModel, ABC):
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs) for inputs in input_list]
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
def run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
@@ -210,17 +236,17 @@ class Chain(BaseModel, ABC):
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return self(args[0])[self.output_keys[0]]
|
||||
return self(args[0], callbacks=callbacks)[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs)[self.output_keys[0]]
|
||||
return self(kwargs, callbacks=callbacks)[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
async def arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
@@ -231,10 +257,10 @@ class Chain(BaseModel, ABC):
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (await self.acall(args[0]))[self.output_keys[0]]
|
||||
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return (await self.acall(kwargs))[self.output_keys[0]]
|
||||
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
|
||||
@@ -1,13 +1,13 @@
"""Chain for applying constitutional principles to the outputs of another chain."""
from typing import Any, Dict, List, Optional

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.principles import PRINCIPLES
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class ConstitutionalChain(Chain):

@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from pydantic import Extra, Field, root_validator

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@@ -15,7 +16,7 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
from langchain.schema import BaseMessage, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore

# Depending on the memory type and configuration, the chat history format may differ.

@@ -5,11 +5,17 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
@@ -53,21 +59,39 @@ class LLMChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.apply([inputs])[0]
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
return self.apply([inputs], run_manager=run_manager)[0]
|
||||
|
||||
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list)
|
||||
return self.llm.generate_prompt(prompts, stop)
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
return self.llm.generate_prompt(
|
||||
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
|
||||
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
async def agenerate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = await self.aprep_prompts(input_list)
|
||||
return await self.llm.agenerate_prompt(prompts, stop)
|
||||
return await self.llm.agenerate_prompt(
|
||||
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
|
||||
def prep_prompts(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
@@ -79,7 +103,8 @@ class LLMChain(Chain):
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if run_manager:
|
||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
@@ -88,7 +113,9 @@ class LLMChain(Chain):
|
||||
return prompts, stop
|
||||
|
||||
async def aprep_prompts(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
@@ -100,12 +127,8 @@ class LLMChain(Chain):
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_text(
|
||||
_text, end="\n", verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if run_manager:
|
||||
                await run_manager.on_text(_text, end="\n", verbose=self.verbose)
            if "stop" in inputs and inputs["stop"] != stop:
                raise ValueError(
                    "If `stop` is present in any inputs, should be present in all."
@@ -113,14 +136,22 @@ class LLMChain(Chain):
            prompts.append(prompt)
        return prompts, stop

    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    def apply(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains."""
        response = self.generate(input_list)
        response = self.generate(input_list, run_manager=run_manager)
        return self.create_outputs(response)

    async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    async def aapply(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains."""
        response = await self.agenerate(input_list)
        response = await self.agenerate(input_list, run_manager=run_manager)
        return self.create_outputs(response)

    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
@@ -131,13 +162,18 @@ class LLMChain(Chain):
            for generation in response.generations
        ]

    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        return (await self.aapply([inputs]))[0]
    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return (await self.aapply([inputs], run_manager=run_manager))[0]

    def predict(self, **kwargs: Any) -> str:
    def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
        """Format prompt with kwargs and pass to LLM.

        Args:
            callbacks: Callbacks to pass to LLMChain
            **kwargs: Keys to pass to prompt template.

        Returns:
@@ -148,12 +184,13 @@ class LLMChain(Chain):

                completion = llm.predict(adjective="funny")
        """
        return self(kwargs)[self.output_key]
        return self(kwargs, callbacks=callbacks)[self.output_key]

    async def apredict(self, **kwargs: Any) -> str:
    async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
        """Format prompt with kwargs and pass to LLM.

        Args:
            callbacks: Callbacks to pass to LLMChain
            **kwargs: Keys to pass to prompt template.

        Returns:
@@ -164,31 +201,35 @@ class LLMChain(Chain):

                completion = llm.predict(adjective="funny")
        """
        return (await self.acall(kwargs))[self.output_key]
        return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]

    def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
    def predict_and_parse(
        self, callbacks: Callbacks = None, **kwargs: Any
    ) -> Union[str, List[str], Dict[str, str]]:
        """Call predict and then parse the results."""
        result = self.predict(**kwargs)
        result = self.predict(callbacks=callbacks, **kwargs)
        if self.prompt.output_parser is not None:
            return self.prompt.output_parser.parse(result)
        else:
            return result

    async def apredict_and_parse(
        self, **kwargs: Any
        self, callbacks: Callbacks = None, **kwargs: Any
    ) -> Union[str, List[str], Dict[str, str]]:
        """Call apredict and then parse the results."""
        result = await self.apredict(**kwargs)
        result = await self.apredict(callbacks=callbacks, **kwargs)
        if self.prompt.output_parser is not None:
            return self.prompt.output_parser.parse(result)
        else:
            return result

    def apply_and_parse(
        self, input_list: List[Dict[str, Any]]
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
        """Call apply and then parse the results."""
        result = self.apply(input_list)
        result = self.apply(input_list, run_manager=run_manager)
        return self._parse_result(result)

    def _parse_result(
@@ -202,10 +243,12 @@ class LLMChain(Chain):
        return result

    async def aapply_and_parse(
        self, input_list: List[Dict[str, Any]]
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
        """Call apply and then parse the results."""
        result = await self.aapply(input_list)
        result = await self.aapply(input_list, run_manager=run_manager)
        return self._parse_result(result)

    @property

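The hunks above add a per-call `callbacks` argument to `predict`/`apredict` and thread an optional run manager through `apply`/`aapply`. A minimal usage sketch, assuming the streaming handler shown earlier in this document; the prompt and model choice here are illustrative, not part of the diff:

```python
from langchain import LLMChain, OpenAI, PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Handlers are now passed per call via `callbacks`, instead of attaching a
# CallbackManager to the chain when it is constructed.
prompt = PromptTemplate(
    input_variables=["adjective"], template="Tell me a {adjective} joke."
)
chain = LLMChain(llm=OpenAI(), prompt=prompt)

completion = chain.predict(
    adjective="funny", callbacks=[StreamingStdOutCallbackHandler()]
)
```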
@@ -3,11 +3,11 @@ from typing import Dict, List

from pydantic import Extra

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities.bash import BashProcess


@@ -1,16 +1,20 @@
"""Chain that interprets a prompt and executes python code to do math."""
import math
import re
from typing import Dict, List
from typing import Dict, List, Optional

import numexpr
from pydantic import Extra

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class LLMMathChain(Chain):
@@ -68,15 +72,19 @@ class LLMMathChain(Chain):
        # Remove any leading and trailing brackets from the output
        return re.sub(r"^\[|\]$", "", output)

    def _process_llm_result(self, llm_output: str) -> Dict[str, str]:
        self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose)
    def _process_llm_result(
        self, llm_output: str, run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> Dict[str, str]:
        if run_manager:
            run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            expression = text_match.group(1)
            output = self._evaluate_expression(expression)
            self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
            self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
            if run_manager:
                run_manager.on_text("\nAnswer: ", verbose=self.verbose)
                run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
@@ -86,30 +94,21 @@ class LLMMathChain(Chain):
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]:
        if self.callback_manager.is_async:
            await self.callback_manager.on_text(
                llm_output, color="green", verbose=self.verbose
            )
        else:
            self.callback_manager.on_text(
                llm_output, color="green", verbose=self.verbose
            )
    async def _aprocess_llm_result(
        self,
        llm_output: str,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        if run_manager:
            await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            expression = text_match.group(1)
            output = self._evaluate_expression(expression)
            if self.callback_manager.is_async:
                await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
                await self.callback_manager.on_text(
                    output, color="yellow", verbose=self.verbose
                )
            else:
                await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
                await self.callback_manager.on_text(
                    output, color="yellow", verbose=self.verbose
                )
            if run_manager:
                await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
                await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
@@ -119,30 +118,35 @@ class LLMMathChain(Chain):
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        llm_executor = LLMChain(
            prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
        )
        self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
        if run_manager:
            run_manager.on_text(inputs[self.input_key])
        llm_output = llm_executor.predict(
            question=inputs[self.input_key], stop=["```output"]
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=run_manager.get_child() if run_manager else None,
        )
        return self._process_llm_result(llm_output)
        return self._process_llm_result(llm_output, run_manager=run_manager)

    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
        llm_executor = LLMChain(
            prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
        )
        if self.callback_manager.is_async:
            await self.callback_manager.on_text(
                inputs[self.input_key], verbose=self.verbose
            )
        else:
            self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
    async def _acall(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
        if run_manager:
            await run_manager.on_text(inputs[self.input_key])
        llm_output = await llm_executor.apredict(
            question=inputs[self.input_key], stop=["```output"]
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=run_manager.get_child() if run_manager else None,
        )
        return await self._aprocess_llm_result(llm_output)
        return await self._aprocess_llm_result(llm_output, run_manager=run_manager)

    @property
    def _chain_type(self) -> str:

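Taken together, the `LLMMathChain` hunks show the pattern this refactor pushes chains toward: report progress through the optional `run_manager` the chain receives, and hand `run_manager.get_child()` to any inner chain so nested runs stay attached to the parent run. A rough sketch of that pattern for a hypothetical custom chain (the class, keys, and prompt handling are illustrative only, not part of the diff):

```python
from typing import Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain


class EchoChain(Chain):
    """Hypothetical chain illustrating the run_manager / get_child() pattern."""

    llm_chain: LLMChain
    input_key: str = "question"
    output_key: str = "answer"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        if run_manager:
            run_manager.on_text(inputs[self.input_key])  # progress for this run
        # get_child() keeps the inner LLMChain run grouped under this chain's run
        text = self.llm_chain.predict(
            question=inputs[self.input_key],
            callbacks=run_manager.get_child() if run_manager else None,
        )
        return {self.output_key: text}
```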
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Optional

from pydantic import Extra

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities import PythonREPL


@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple

from pydantic import BaseModel, Field

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class BasePromptSelector(BaseModel, ABC):

@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Optional

from pydantic import Field

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


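The repeated one-line hunks in this stretch all make the same move: `BaseLanguageModel` now comes from `langchain.base_language` rather than `langchain.schema`. For downstream code that imports the class directly, the change is just:

```python
# before
from langchain.schema import BaseLanguageModel
# after
from langchain.base_language import BaseLanguageModel
```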
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@@ -21,7 +22,6 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class BaseQAWithSourcesChain(Chain, ABC):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Load question answering with sources chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
@@ -14,7 +15,6 @@ from langchain.chains.qa_with_sources import (
|
||||
)
|
||||
from langchain.chains.question_answering import map_rerank_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Load question answering chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
@@ -14,7 +15,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Load summarizing chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
@@ -8,7 +9,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
||||
@@ -2,6 +2,10 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.anthropic import _AnthropicCommon
|
||||
from langchain.schema import (
|
||||
@@ -85,7 +89,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
) # trim off the trailing ' ' that might come from the "Assistant: "
|
||||
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
|
||||
@@ -98,10 +105,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
for data in stream_resp:
|
||||
delta = data["completion"][len(completion) :]
|
||||
completion = data["completion"]
|
||||
self.callback_manager.on_llm_new_token(
|
||||
delta,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
delta,
|
||||
)
|
||||
else:
|
||||
response = self.client.completion(**params)
|
||||
completion = response["completion"]
|
||||
@@ -109,7 +116,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
|
||||
@@ -122,15 +132,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
async for data in stream_resp:
|
||||
delta = data["completion"][len(completion) :]
|
||||
completion = data["completion"]
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_new_token(
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
delta,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_new_token(
|
||||
delta,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
response = await self.client.acompletion(**params)
|
||||
|
||||
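In the `ChatAnthropic` hunks above (and the `ChatOpenAI` ones further down), streamed tokens no longer go through `self.callback_manager`; they are forwarded to the per-run manager via `run_manager.on_llm_new_token`. From the caller's side, that means any handler passed through `callbacks` sees the tokens. A minimal sketch, assuming the wrapper's `streaming` flag and a handler that simply prints each token; the handler name and prompt are illustrative:

```python
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage


class PrintTokenHandler(BaseCallbackHandler):
    """Illustrative handler: print each streamed token as it arrives."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end="", flush=True)


chat = ChatAnthropic(streaming=True)
chat([HumanMessage(content="Tell me a joke")], callbacks=[PrintTokenHandler()])
```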
@@ -1,21 +1,30 @@
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Dict, List, Optional

from pydantic import Extra, Field, validator
from pydantic import Extra, Field, root_validator

import langchain
from langchain.callbacks import get_callback_manager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForLLMRun,
    CallbackManager,
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.schema import (
    AIMessage,
    BaseLanguageModel,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    PromptValue,
    get_buffer_string,
)


@@ -26,7 +35,19 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, ABC):
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
    callbacks: Callbacks = None
    callback_manager: Optional[BaseCallbackManager] = None

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values

class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -34,98 +55,130 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
return {}
|
||||
|
||||
def generate(
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
results = [self._generate(m, stop=stop) for m in messages]
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
message_strings = [get_buffer_string(m) for m in messages]
|
||||
run_manager = callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, message_strings
|
||||
)
|
||||
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
try:
|
||||
results = [
|
||||
self._generate(m, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._generate(m, stop=stop)
|
||||
for m in messages
|
||||
]
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
run_manager.on_llm_end(output)
|
||||
return output
|
||||
|
||||
async def agenerate(
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
results = await asyncio.gather(
|
||||
*[self._agenerate(m, stop=stop) for m in messages]
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
message_strings = [get_buffer_string(m) for m in messages]
|
||||
run_manager = await callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, message_strings
|
||||
)
|
||||
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
try:
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._agenerate(m, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._agenerate(m, stop=stop)
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_llm_error(e)
|
||||
raise e
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
await run_manager.on_llm_end(output)
|
||||
return output
|
||||
|
||||
def generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
|
||||
)
|
||||
try:
|
||||
output = self.generate(prompt_messages, stop=stop)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
return output
|
||||
return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
|
||||
|
||||
async def agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
|
||||
)
|
||||
try:
|
||||
output = await self.agenerate(prompt_messages, stop=stop)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
return output
|
||||
return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
|
||||
@abstractmethod
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
|
||||
def __call__(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> BaseMessage:
|
||||
return self._generate(messages, stop=stop).generations[0].message
|
||||
generation = self.generate(
|
||||
[messages], stop=stop, callbacks=callbacks
|
||||
).generations[0][0]
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
|
||||
result = self([HumanMessage(content=message)], stop=stop)
|
||||
@@ -134,15 +187,21 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
|
||||
class SimpleChatModel(BaseChatModel):
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop)
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager)
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Simpler interface."""
|
||||
|
||||
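The `SimpleChatModel` hunk above gives the shape a minimal custom chat model now takes: `_call` accepts the optional `run_manager` alongside `messages` and `stop`. A rough sketch under that assumption (it elides any other members the base class may require, and the model itself is a stub that echoes the last message):

```python
from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class ParrotChatModel(SimpleChatModel):
    """Illustrative chat model: echo the last message back."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        return messages[-1].content
```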
@@ -14,6 +14,10 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
@@ -242,7 +246,10 @@ class ChatOpenAI(BaseChatModel):
|
||||
return {"token_usage": overall_token_usage, "model_name": self.model_name}
|
||||
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
if self.streaming:
|
||||
@@ -255,10 +262,8 @@ class ChatOpenAI(BaseChatModel):
|
||||
role = stream_resp["choices"][0]["delta"].get("role", role)
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
inner_completion += token
|
||||
self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token)
|
||||
message = _convert_dict_to_message(
|
||||
{"content": inner_completion, "role": role}
|
||||
)
|
||||
@@ -287,7 +292,10 @@ class ChatOpenAI(BaseChatModel):
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
if self.streaming:
|
||||
@@ -300,16 +308,8 @@ class ChatOpenAI(BaseChatModel):
|
||||
role = stream_resp["choices"][0]["delta"].get("role", role)
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
inner_completion += token
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token)
|
||||
message = _convert_dict_to_message(
|
||||
{"content": inner_completion, "role": role}
|
||||
)
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import BaseMessage, ChatResult
|
||||
|
||||
@@ -33,13 +37,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
|
||||
return_pl_id: Optional[bool] = False
|
||||
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
"""Call ChatOpenAI generate and then call PromptLayer API to log the request."""
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request
|
||||
|
||||
request_start_time = datetime.datetime.now().timestamp()
|
||||
generated_responses = super()._generate(messages, stop)
|
||||
generated_responses = super()._generate(messages, stop, run_manager)
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
message_dicts, params = super()._create_message_dicts(messages, stop)
|
||||
for i, generation in enumerate(generated_responses.generations):
|
||||
@@ -67,13 +74,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
|
||||
return generated_responses
|
||||
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
"""Call ChatOpenAI agenerate and then call PromptLayer to log."""
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request_async
|
||||
|
||||
request_start_time = datetime.datetime.now().timestamp()
|
||||
generated_responses = await super()._agenerate(messages, stop)
|
||||
generated_responses = await super()._agenerate(messages, stop, run_manager)
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
message_dicts, params = super()._create_message_dicts(messages, stop)
|
||||
for i, generation in enumerate(generated_responses.generations):
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.experimental.autonomous_agents.baby_agi.task_creation import (
|
||||
TaskCreationChain,
|
||||
@@ -13,7 +14,6 @@ from langchain.experimental.autonomous_agents.baby_agi.task_execution import (
|
||||
from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import (
|
||||
TaskPrioritizationChain,
|
||||
)
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
|
||||
class TaskCreationChain(LLMChain):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
|
||||
class TaskExecutionChain(LLMChain):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
|
||||
class TaskPrioritizationChain(LLMChain):
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -106,7 +107,12 @@ class AI21(LLM):
|
||||
"""Return type of llm."""
|
||||
return "ai21"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to AI21's complete endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -200,7 +201,12 @@ class AlephAlpha(LLM):
|
||||
"""Return type of llm."""
|
||||
return "alpeh_alpha"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to Aleph Alpha's completion endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,10 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -142,7 +146,12 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
# As a last resort, wrap the prompt ourselves to emulate instruct-style.
|
||||
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
r"""Call out to Anthropic's completion endpoint.
|
||||
|
||||
Args:
|
||||
@@ -171,9 +180,8 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
for data in stream_resp:
|
||||
delta = data["completion"][len(current_completion) :]
|
||||
current_completion = data["completion"]
|
||||
self.callback_manager.on_llm_new_token(
|
||||
delta, verbose=self.verbose, **data
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta, **data)
|
||||
return current_completion
|
||||
response = self.client.completion(
|
||||
prompt=self._wrap_prompt(prompt),
|
||||
@@ -182,7 +190,12 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
)
|
||||
return response["completion"]
|
||||
|
||||
async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to Anthropic's completion endpoint asynchronously."""
|
||||
stop = self._get_anthropic_stop(stop)
|
||||
if self.streaming:
|
||||
@@ -195,14 +208,8 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
async for data in stream_resp:
|
||||
delta = data["completion"][len(current_completion) :]
|
||||
current_completion = data["completion"]
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_new_token(
|
||||
delta, verbose=self.verbose, **data
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_new_token(
|
||||
delta, verbose=self.verbose, **data
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta, **data)
|
||||
return current_completion
|
||||
response = await self.client.acompletion(
|
||||
prompt=self._wrap_prompt(prompt),
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -80,7 +81,12 @@ class Banana(LLM):
|
||||
"""Return type of llm."""
|
||||
return "banana"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call to Banana endpoint."""
|
||||
try:
|
||||
import banana_dev as banana
|
||||
|
||||
@@ -1,16 +1,25 @@
"""Base interface for large language models to expose."""
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import yaml
from pydantic import Extra, Field, validator
from pydantic import Extra, Field, root_validator, validator

import langchain
from langchain.callbacks import get_callback_manager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForLLMRun,
    CallbackManager,
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.schema import Generation, LLMResult, PromptValue


def _get_verbosity() -> bool:
@@ -59,7 +68,8 @@ class BaseLLM(BaseLanguageModel, ABC):
    cache: Optional[bool] = None
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
    callbacks: Callbacks = None
    callback_manager: Optional[BaseCallbackManager] = None

    class Config:
        """Configuration for this pydantic object."""
@@ -67,15 +77,16 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
@@ -90,30 +101,45 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompts."""
|
||||
|
||||
@abstractmethod
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompts."""
|
||||
|
||||
def generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
return self.generate(prompt_strings, stop=stop)
|
||||
return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
|
||||
|
||||
async def agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
return await self.agenerate(prompt_strings, stop=stop)
|
||||
return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
|
||||
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# If string is passed in directly no errors will be raised but outputs will
|
||||
@@ -124,21 +150,31 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
f" argument of type {type(prompts)}."
|
||||
)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
|
||||
run_manager = callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts
|
||||
)
|
||||
try:
|
||||
output = self._generate(prompts, stop=stop)
|
||||
output = (
|
||||
self._generate(prompts, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._generate(prompts, stop=stop)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
run_manager.on_llm_end(output)
|
||||
return output
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
@@ -149,15 +185,19 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
missing_prompts,
|
||||
) = get_prompts(params, prompts)
|
||||
if len(missing_prompts) > 0:
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
|
||||
run_manager = self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, missing_prompts
|
||||
)
|
||||
try:
|
||||
new_results = self._generate(missing_prompts, stop=stop)
|
||||
new_results = (
|
||||
self._generate(missing_prompts, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._generate(missing_prompts, stop=stop)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
|
||||
run_manager.on_llm_end(new_results)
|
||||
llm_output = update_cache(
|
||||
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
|
||||
)
|
||||
@@ -167,36 +207,38 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
|
||||
)
|
||||
run_manager = await callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts
|
||||
)
|
||||
try:
|
||||
output = await self._agenerate(prompts, stop=stop)
|
||||
output = (
|
||||
await self._agenerate(prompts, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._agenerate(prompts, stop=stop)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
await run_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
await run_manager.on_llm_end(output, verbose=self.verbose)
|
||||
return output
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
@@ -207,32 +249,22 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
missing_prompts,
|
||||
) = get_prompts(params, prompts)
|
||||
if len(missing_prompts) > 0:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__},
|
||||
missing_prompts,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__},
|
||||
missing_prompts,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
run_manager = await callback_manager.on_llm_start(
|
||||
{"name": self.__class__.__name__},
|
||||
missing_prompts,
|
||||
)
|
||||
try:
|
||||
new_results = await self._agenerate(missing_prompts, stop=stop)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_end(
|
||||
new_results, verbose=self.verbose
|
||||
new_results = (
|
||||
await self._agenerate(
|
||||
missing_prompts, stop=stop, run_manager=run_manager
|
||||
)
|
||||
if new_arg_supported
|
||||
else await self._agenerate(missing_prompts, stop=stop)
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_llm_error(e)
|
||||
raise e
|
||||
await run_manager.on_llm_end(new_results)
|
||||
llm_output = update_cache(
|
||||
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
|
||||
)
|
||||
@@ -241,9 +273,15 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
generations = [existing_prompts[i] for i in range(len(prompts))]
|
||||
return LLMResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def __call__(
|
||||
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
|
||||
) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input."""
|
||||
return self.generate([prompt], stop=stop).generations[0][0].text
|
||||
return (
|
||||
self.generate([prompt], stop=stop, callbacks=callbacks)
|
||||
.generations[0][0]
|
||||
.text
|
||||
)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
@@ -307,30 +345,56 @@ class LLM(BaseLLM):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
|
||||
async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
raise NotImplementedError("Async generation not implemented for this LLM.")
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# TODO: add caching here.
|
||||
generations = []
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
for prompt in prompts:
|
||||
text = self._call(prompt, stop=stop)
|
||||
text = (
|
||||
self._call(prompt, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(prompt, stop=stop)
|
||||
)
|
||||
generations.append([Generation(text=text)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
generations = []
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
for prompt in prompts:
|
||||
text = await self._acall(prompt, stop=stop)
|
||||
text = (
|
||||
await self._acall(prompt, stop=stop, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(prompt, stop=stop)
|
||||
)
|
||||
generations.append([Generation(text=text)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
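The `LLM` convenience class above now inspects `_call`/`_acall` for a `run_manager` parameter, so existing wrappers keep working and new ones opt in simply by accepting the argument. A rough sketch of a custom wrapper under those assumptions; the class is a stub that returns the prompt unchanged while streaming it word by word:

```python
from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class RepeatLLM(LLM):
    """Illustrative wrapper: return the prompt, streaming it word by word."""

    @property
    def _llm_type(self) -> str:
        return "repeat"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        for word in prompt.split():
            if run_manager:
                # tokens now go to the per-run manager, not self.callback_manager
                run_manager.on_llm_new_token(word)
        return prompt
```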
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -81,7 +82,12 @@ class CerebriumAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "cerebriumai"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call to CerebriumAI endpoint."""
|
||||
try:
|
||||
from cerebrium import model_api_request
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -100,7 +101,12 @@ class Cohere(LLM):
|
||||
"""Return type of llm."""
|
||||
return "cohere"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to Cohere's generate endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -60,7 +61,12 @@ class DeepInfra(LLM):
|
||||
"""Return type of llm."""
|
||||
return "deepinfra"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to DeepInfra's inference API endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
@@ -15,7 +16,12 @@ class FakeListLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake-list"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""First try to lookup in queries, else return 'foo' or 'bar'."""
|
||||
response = self.responses[self.i]
|
||||
self.i += 1
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -81,7 +82,12 @@ class ForefrontAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "forefrontai"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to ForefrontAI's complete endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -130,7 +131,12 @@ class GooseAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "gooseai"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call the GooseAI API."""
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -159,7 +160,12 @@ class GPT4All(LLM):
|
||||
"""Return the type of llm."""
|
||||
return "gpt4all"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
r"""Call out to GPT4All's generate method.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -88,7 +89,12 @@ class HuggingFaceEndpoint(LLM):
|
||||
"""Return type of llm."""
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to HuggingFace Hub's inference endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -84,7 +85,12 @@ class HuggingFaceHub(LLM):
|
||||
"""Return type of llm."""
|
||||
return "huggingface_hub"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to HuggingFace Hub's inference endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -146,7 +147,12 @@ class HuggingFacePipeline(LLM):
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface_pipeline"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
response = self.pipeline(prompt)
|
||||
if self.pipeline.task == "text-generation":
|
||||
# Text generation return includes the starter text.
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -154,7 +155,12 @@ class LlamaCpp(LLM):
|
||||
"""Return type of llm."""
|
||||
return "llama.cpp"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call the Llama model and return the output.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
@@ -42,7 +43,12 @@ class ManifestWrapper(LLM):
|
||||
"""Return type of llm."""
|
||||
return "manifest"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to LLM through Manifest."""
|
||||
if stop is not None and len(stop) != 1:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -69,7 +70,12 @@ class Modal(LLM):
|
||||
"""Return type of llm."""
|
||||
return "modal"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call to Modal endpoint."""
|
||||
params = self.model_kwargs or {}
|
||||
response = requests.post(
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -111,7 +112,12 @@ class NLPCloud(LLM):
|
||||
"""Return type of llm."""
|
||||
return "nlpcloud"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to NLPCloud's create endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -29,6 +29,10 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -253,7 +257,10 @@ class BaseOpenAI(BaseLLM):
        return {**normal_params, **self.model_kwargs}

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

@@ -286,11 +293,12 @@ class BaseOpenAI(BaseLLM):
                for stream_resp in completion_with_retry(
                    self, prompt=_prompts, **params
                ):
                    self.callback_manager.on_llm_new_token(
                        stream_resp["choices"][0]["text"],
                        verbose=self.verbose,
                        logprobs=stream_resp["choices"][0]["logprobs"],
                    )
                    if run_manager:
                        run_manager.on_llm_new_token(
                            stream_resp["choices"][0]["text"],
                            verbose=self.verbose,
                            logprobs=stream_resp["choices"][0]["logprobs"],
                        )
                    _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
@@ -302,7 +310,10 @@ class BaseOpenAI(BaseLLM):
|
||||
return self.create_llm_result(choices, prompts, token_usage)
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Call out to OpenAI's endpoint async with k unique prompts."""
|
||||
params = self._invocation_params
|
||||
@@ -321,14 +332,8 @@ class BaseOpenAI(BaseLLM):
|
||||
async for stream_resp in await acompletion_with_retry(
|
||||
self, prompt=_prompts, **params
|
||||
):
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_new_token(
|
||||
stream_resp["choices"][0]["text"],
|
||||
verbose=self.verbose,
|
||||
logprobs=stream_resp["choices"][0]["logprobs"],
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_new_token(
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
stream_resp["choices"][0]["text"],
|
||||
verbose=self.verbose,
|
||||
logprobs=stream_resp["choices"][0]["logprobs"],
|
||||
@@ -704,7 +709,10 @@ class OpenAIChat(BaseLLM):
|
||||
return messages, params
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
messages, params = self._get_chat_params(prompts, stop)
|
||||
if self.streaming:
|
||||
@@ -713,10 +721,10 @@ class OpenAIChat(BaseLLM):
|
||||
for stream_resp in completion_with_retry(self, messages=messages, **params):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
response += token
|
||||
self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token,
|
||||
)
|
||||
return LLMResult(
|
||||
generations=[[Generation(text=response)]],
|
||||
)
|
||||
@@ -734,7 +742,10 @@ class OpenAIChat(BaseLLM):
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
messages, params = self._get_chat_params(prompts, stop)
|
||||
if self.streaming:
|
||||
@@ -745,15 +756,9 @@ class OpenAIChat(BaseLLM):
|
||||
):
|
||||
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||
response += token
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_new_token(
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
return LLMResult(
|
||||
generations=[[Generation(text=response)]],
|
||||
|
||||
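In the streaming branches above, the per-run manager passed into `_generate`/`_agenerate` replaces the shared `self.callback_manager`, and the `is_async` branching disappears because the async run manager is simply awaited. A hedged sketch of the token-forwarding pattern in isolation (the `stream_tokens` helper is illustrative; the real code iterates over the OpenAI `stream_resp` objects shown above):

```python
# Sketch of the new streaming pattern, not a verbatim excerpt: tokens are
# forwarded through the per-run manager instead of a shared callback_manager.
from typing import Iterable, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun


def stream_tokens(
    tokens: Iterable[str],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
    """Accumulate tokens, emitting each one to the run manager if present."""
    response = ""
    for token in tokens:  # stands in for the stream_resp loop above
        response += token
        if run_manager:
            run_manager.on_llm_new_token(token)
    return response
```

In `_agenerate` the same call is awaited on an `AsyncCallbackManagerForLLMRun`, so the old `callback_manager.is_async` check is no longer needed.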
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -130,7 +131,12 @@ class Petals(LLM):
|
||||
"""Return type of llm."""
|
||||
return "petals"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call the Petals API."""
|
||||
params = self._default_params
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms import OpenAI, OpenAIChat
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -33,13 +37,16 @@ class PromptLayerOpenAI(OpenAI):
|
||||
return_pl_id: Optional[bool] = False
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Call OpenAI generate and then call PromptLayer API to log the request."""
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request
|
||||
|
||||
request_start_time = datetime.datetime.now().timestamp()
|
||||
generated_responses = super()._generate(prompts, stop)
|
||||
generated_responses = super()._generate(prompts, stop, run_manager)
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
@@ -69,12 +76,15 @@ class PromptLayerOpenAI(OpenAI):
|
||||
return generated_responses
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request_async
|
||||
|
||||
request_start_time = datetime.datetime.now().timestamp()
|
||||
generated_responses = await super()._agenerate(prompts, stop)
|
||||
generated_responses = await super()._agenerate(prompts, stop, run_manager)
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
@@ -131,13 +141,16 @@ class PromptLayerOpenAIChat(OpenAIChat):
|
||||
return_pl_id: Optional[bool] = False
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Call OpenAI generate and then call PromptLayer API to log the request."""
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request
|
||||
|
||||
request_start_time = datetime.datetime.now().timestamp()
|
||||
generated_responses = super()._generate(prompts, stop)
|
||||
generated_responses = super()._generate(prompts, stop, run_manager)
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
@@ -167,12 +180,15 @@ class PromptLayerOpenAIChat(OpenAIChat):
|
||||
return generated_responses
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request_async
|
||||
|
||||
request_start_time = datetime.datetime.now().timestamp()
|
||||
generated_responses = await super()._agenerate(prompts, stop)
|
||||
generated_responses = await super()._agenerate(prompts, stop, run_manager)
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -78,7 +79,12 @@ class Replicate(LLM):
|
||||
"""Return type of model."""
|
||||
return "replicate"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call to replicate endpoint."""
|
||||
try:
|
||||
import replicate as replicate_python
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -204,7 +205,12 @@ class RWKV(LLM, BaseModel):
|
||||
|
||||
return decoded
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
r"""RWKV generation
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -194,7 +195,12 @@ class SagemakerEndpoint(LLM):
|
||||
"""Return type of llm."""
|
||||
return "sagemaker_endpoint"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to Sagemaker inference endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Callable, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -208,5 +209,10 @@ class SelfHostedPipeline(LLM):
|
||||
def _llm_type(self) -> str:
|
||||
return "self_hosted_llm"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.self_hosted import SelfHostedPipeline
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -198,5 +199,10 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
|
||||
def _llm_type(self) -> str:
|
||||
return "selfhosted_huggingface_pipeline"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -80,7 +81,12 @@ class StochasticAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "stochasticai"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to StochasticAI's complete endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -117,7 +118,12 @@ class Writer(LLM):
|
||||
"""Return type of llm."""
|
||||
return "writer"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Call out to Writer's complete endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import (
|
||||
@@ -13,7 +14,7 @@ from langchain.memory.prompt import (
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
|
||||
from langchain.schema import BaseMessage, get_buffer_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs import NetworkxEntityGraph
|
||||
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
|
||||
@@ -13,7 +14,6 @@ from langchain.memory.prompt import (
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Type
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
|
||||
from langchain.schema import BaseMessage, get_buffer_string
|
||||
|
||||
|
||||
class ConversationTokenBufferMemory(BaseChatMemory):
|
||||
|
||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@ from __future__ import annotations
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
BaseOutputParser,
|
||||
OutputParserException,
|
||||
PromptValue,
|
||||
|
||||
@@ -171,45 +171,6 @@ class PromptValue(BaseModel, ABC):
|
||||
"""Return prompt as messages."""
|
||||
|
||||
|
||||
class BaseLanguageModel(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text."""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate get_num_tokens. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-3 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
# tokenize the text using the GPT-3 tokenizer
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
||||
# calculate the number of tokens in the tokenized text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the message."""
|
||||
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
|
||||
|
||||
|
||||
class BaseMemory(BaseModel, ABC):
|
||||
"""Base interface for memory in chains."""
|
||||
|
||||
|
||||
@@ -1,13 +1,21 @@
"""Base implementation for tools or skills."""

import inspect
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

from pydantic import BaseModel, Extra, Field, validate_arguments, validator
from pydantic import BaseModel, Extra, Field, root_validator, validate_arguments

from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForToolRun,
    CallbackManager,
    CallbackManagerForToolRun,
    Callbacks,
)


def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
@@ -28,7 +36,8 @@ class BaseTool(ABC, BaseModel):
|
||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
callbacks: Callbacks = None
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -60,15 +69,16 @@ class BaseTool(ABC, BaseModel):
|
||||
if input_args is not None:
|
||||
input_args.validate(tool_input)
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
@@ -84,6 +94,7 @@ class BaseTool(ABC, BaseModel):
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool."""
|
||||
@@ -92,22 +103,22 @@ class BaseTool(ABC, BaseModel):
            verbose_ = verbose
        else:
            verbose_ = self.verbose
        self.callback_manager.on_tool_start(
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, verbose=verbose_
        )
        run_manager = callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input if isinstance(tool_input, str) else str(tool_input),
            verbose=verbose_,
            color=start_color,
            **kwargs,
        )
        try:
            tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
            observation = self._run(*tool_args, **tool_kwargs)
            observation = self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
        except (Exception, KeyboardInterrupt) as e:
            self.callback_manager.on_tool_error(e, verbose=verbose_)
            run_manager.on_tool_error(e)
            raise e
        self.callback_manager.on_tool_end(
            observation, verbose=verbose_, color=color, name=self.name, **kwargs
        )
        run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
        return observation

    async def arun(
@@ -116,6 +127,7 @@ class BaseTool(ABC, BaseModel):
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool asynchronously."""
|
||||
@@ -124,42 +136,27 @@ class BaseTool(ABC, BaseModel):
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, verbose=verbose_
|
||||
)
|
||||
run_manager = await callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
args, kwargs = _to_args_and_kwargs(tool_input)
|
||||
observation = await self._arun(*args, **kwargs)
|
||||
observation = await self._arun(*args, run_manager=run_manager, **kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
else:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
await run_manager.on_tool_end(
|
||||
observation, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: str) -> str:
|
||||
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
return self.run(tool_input, callbacks=callbacks)
|
||||
|
||||
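On the tool side, `run()`/`arun()` now build a manager with `CallbackManager.configure(...)`, hand a per-run manager to `_run`/`_arun`, and the old `callback_manager` field is deprecated in favour of `callbacks` by the `raise_deprecation` validator above. A rough sketch of a custom tool written against this interface (the `ReverseTool` class is illustrative only, not part of this diff):

```python
# Illustrative only: a trivial tool whose _run accepts the run manager that
# BaseTool.run now passes through (run_manager=run_manager in the call above).
from typing import Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool


class ReverseTool(BaseTool):
    name = "reverse"
    description = "Reverses the input string."

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        if run_manager:
            run_manager.on_text(query)  # run managers expose on_text, as in the tests below
        return query[::-1]

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return query[::-1]
```

Handlers can still be supplied per call, e.g. `ReverseTool().run("hello", callbacks=[...])`, which is what the new `callbacks` argument on `run`, `arun`, and `__call__` is for.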
@@ -76,7 +76,7 @@ class SerpAPIWrapper(BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
async def arun(self, query: str) -> str:
|
||||
async def arun(self, query: str, **kwargs: Any) -> str:
|
||||
"""Use aiohttp to run query through SerpAPI and parse result."""
|
||||
|
||||
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
|
||||
@@ -99,7 +99,7 @@ class SerpAPIWrapper(BaseModel):
|
||||
|
||||
return self._process_response(res)
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
def run(self, query: str, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result."""
|
||||
return self._process_response(self.results(query))
|
||||
|
||||
|
||||
tests/integration_tests/callbacks/__init__.py (new file, 0 lines)
tests/integration_tests/callbacks/test_langchain_tracer.py (new file, 80 lines)
@@ -0,0 +1,80 @@
|
||||
"""Integration tests for the langchain tracer module."""
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.callbacks import tracing_enabled
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
questions = [
|
||||
"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?",
|
||||
"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?",
|
||||
"Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?",
|
||||
"Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?",
|
||||
"Who is Beyonce's husband? What is his age raised to the 0.19 power?",
|
||||
]
|
||||
|
||||
|
||||
def test_tracing_sequential() -> None:
|
||||
os.environ["LANGCHAIN_TRACING"] = "true"
|
||||
|
||||
for q in questions[:3]:
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(q)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracing_concurrent() -> None:
|
||||
os.environ["LANGCHAIN_TRACING"] = "true"
|
||||
aiosession = ClientSession()
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
tasks = [agent.arun(q) for q in questions[:3]]
|
||||
await asyncio.gather(*tasks)
|
||||
await aiosession.close()
|
||||
|
||||
|
||||
def test_tracing_context_manager() -> None:
|
||||
llm = OpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING"]
|
||||
with tracing_enabled() as session:
|
||||
assert session
|
||||
agent.run(questions[0]) # this should be traced
|
||||
|
||||
agent.run(questions[0]) # this should not be traced
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracing_context_manager_async() -> None:
|
||||
llm = OpenAI(temperature=0)
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
if "LANGCHAIN_TRACING" in os.environ:
|
||||
del os.environ["LANGCHAIN_TRACING"]
|
||||
|
||||
# start a background task
|
||||
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
|
||||
with tracing_enabled() as session:
|
||||
assert session
|
||||
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
await task
|
||||
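The tracing tests above exercise two switches: the `LANGCHAIN_TRACING` environment variable and the `tracing_enabled()` context manager, which scopes tracing to a block. A minimal usage sketch mirroring the context-manager test (assumes an OpenAI API key; the question is arbitrary):

```python
# Minimal sketch: scope tracing to a single agent run without setting
# the LANGCHAIN_TRACING environment variable.
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
tools = load_tools(["llm-math"], llm=llm)
agent = initialize_agent(
    tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)

with tracing_enabled() as session:
    assert session  # a tracing session is active only inside this block
    agent.run("What is 2 raised to the 0.43 power?")  # traced

agent.run("What is 2 raised to the 0.43 power?")  # not traced
```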
tests/integration_tests/callbacks/test_openai_callback.py (new file, 37 lines)
@@ -0,0 +1,37 @@
|
||||
"""Integration tests for the langchain tracer module."""
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain import OpenAI
|
||||
from langchain.callbacks import get_openai_callback
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_callback() -> None:
|
||||
llm = OpenAI(temperature=0)
|
||||
with get_openai_callback() as cb:
|
||||
llm("What is the square root of 4?")
|
||||
|
||||
total_tokens = cb.total_tokens
|
||||
assert total_tokens > 0
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
llm("What is the square root of 4?")
|
||||
llm("What is the square root of 4?")
|
||||
|
||||
assert cb.total_tokens == total_tokens * 2
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
await asyncio.gather(
|
||||
*[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)]
|
||||
)
|
||||
|
||||
assert cb.total_tokens == total_tokens * 3
|
||||
|
||||
task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"]))
|
||||
with get_openai_callback() as cb:
|
||||
await llm.agenerate(["What is the square root of 4?"])
|
||||
|
||||
await task
|
||||
assert cb.total_tokens == total_tokens
|
||||
@@ -3,7 +3,7 @@ from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.anthropic import ChatAnthropic
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.llms.anthropic import Anthropic
|
||||
from langchain.schema import LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.llms.loading import load_llm
|
||||
from langchain.llms.openai import OpenAI, OpenAIChat
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.llms.base import LLM
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -17,12 +16,72 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
ignore_agent_: bool = False
|
||||
always_verbose_: bool = False
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return self.always_verbose_
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
llm_streams: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_actions: int = 0
|
||||
agent_ends: int = 0
|
||||
|
||||
|
||||
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
"""Base fake callback handler mixin for testing."""
|
||||
|
||||
def on_llm_start_common(self) -> None:
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end_common(self) -> None:
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_llm_new_token_common(self) -> None:
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_chain_start_common(self) -> None:
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end_common(self) -> None:
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start_common(self) -> None:
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end_common(self) -> None:
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_agent_action_common(self) -> None:
|
||||
self.agent_actions += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_agent_finish_common(self) -> None:
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_text_common(self) -> None:
|
||||
self.text += 1
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
@@ -39,164 +98,209 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
llm_streams: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_ends: int = 0
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_start_common()
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.llm_streams += 1
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_end_common()
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_error_common()
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_start_common()
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_end_common()
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_error_common()
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_start_common()
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_end_common()
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_error_common()
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.text += 1
|
||||
def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_agent_action_common()
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake async callback handler for testing."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
self.on_llm_start_common()
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.llm_streams += 1
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
async def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_end_common()
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
self.on_llm_error_common()
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
self.on_chain_start_common()
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
async def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_chain_end_common()
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
self.on_chain_error_common()
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
self.on_tool_start_common()
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
async def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_tool_end_common()
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
self.on_tool_error_common()
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.text += 1
|
||||
async def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_agent_action_common()
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
async def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
@@ -3,13 +3,9 @@ from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentFinish, LLMResult
|
||||
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
@@ -18,19 +14,26 @@ from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
manager: CallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, "")
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
    run_manager = manager.on_llm_start({}, [])
    run_manager.on_llm_end(LLMResult(generations=[]))
    run_manager.on_llm_error(Exception())
    run_manager.on_llm_new_token("foo")
    run_manager.on_text("foo")

    run_manager = manager.on_chain_start({"name": "foo"}, {})
    run_manager.on_chain_end({})
    run_manager.on_chain_error(Exception())
    run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
    run_manager.on_agent_finish(AgentFinish(log="", return_values={}))
    run_manager.on_text("foo")

    run_manager = manager.on_tool_start({}, "")
    run_manager.on_tool_end("")
    run_manager.on_tool_error(Exception())
    run_manager.on_text("foo")
    _check_num_calls(handlers)
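The reworked helper above captures the new two-step flow: a `CallbackManager` hands out a per-run manager from each `on_*_start` call, and every later event for that run goes through the run manager. A small sketch of driving that flow directly with the stdout handler imported in this test module (illustrative; production code reaches this path through chains, LLMs, and tools rather than calling it by hand):

```python
# Sketch of the run-manager flow exercised by the test helper above.
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.schema import LLMResult

manager = CallbackManager([StdOutCallbackHandler()])

# Starting an LLM run returns a run manager scoped to that run.
run_manager = manager.on_llm_start({"name": "my-llm"}, ["Hello"])
run_manager.on_llm_new_token("Hi")
run_manager.on_llm_end(LLMResult(generations=[]))
```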
@@ -38,75 +41,60 @@ async def _test_callback_manager_async(
    manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
    """Test the CallbackManager."""
    await manager.on_llm_start({}, [])
    await manager.on_llm_end(LLMResult(generations=[]))
    await manager.on_llm_error(Exception())
    await manager.on_chain_start({"name": "foo"}, {})
    await manager.on_chain_end({})
    await manager.on_chain_error(Exception())
    await manager.on_tool_start({}, "")
    await manager.on_tool_end("")
    await manager.on_tool_error(Exception())
    await manager.on_agent_finish(AgentFinish(log="", return_values={}))
    run_manager = await manager.on_llm_start({}, [])
    await run_manager.on_llm_end(LLMResult(generations=[]))
    await run_manager.on_llm_error(Exception())
    await run_manager.on_llm_new_token("foo")
    await run_manager.on_text("foo")

    run_manager = await manager.on_chain_start({"name": "foo"}, {})
    await run_manager.on_chain_end({})
    await run_manager.on_chain_error(Exception())
    await run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
    await run_manager.on_agent_finish(AgentFinish(log="", return_values={}))
    await run_manager.on_text("foo")

    run_manager = await manager.on_tool_start({}, "")
    await run_manager.on_tool_end("")
    await run_manager.on_tool_error(Exception())
    await run_manager.on_text("foo")
    _check_num_calls(handlers)


def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
    for handler in handlers:
        if handler.always_verbose:
            assert handler.starts == 3
            assert handler.ends == 4
            assert handler.errors == 3
        else:
            assert handler.starts == 0
            assert handler.ends == 0
            assert handler.errors == 0


def _test_callback_manager_pass_in_verbose(
    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
    """Test the CallbackManager."""
    manager.on_llm_start({}, [], verbose=True)
    manager.on_llm_end(LLMResult(generations=[]), verbose=True)
    manager.on_llm_error(Exception(), verbose=True)
    manager.on_chain_start({"name": "foo"}, {}, verbose=True)
    manager.on_chain_end({}, verbose=True)
    manager.on_chain_error(Exception(), verbose=True)
    manager.on_tool_start({}, "", verbose=True)
    manager.on_tool_end("", verbose=True)
    manager.on_tool_error(Exception(), verbose=True)
    manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
    for handler in handlers:
        assert handler.starts == 3
        assert handler.starts == 4
        assert handler.ends == 4
        assert handler.errors == 3
        assert handler.text == 3

        assert handler.llm_starts == 1
        assert handler.llm_ends == 1
        assert handler.llm_streams == 1

        assert handler.chain_starts == 1
        assert handler.chain_ends == 1

        assert handler.tool_starts == 1
        assert handler.tool_ends == 1


def test_callback_manager() -> None:
    """Test the CallbackManager."""
    handler1 = FakeCallbackHandler(always_verbose_=True)
    handler2 = FakeCallbackHandler(always_verbose_=False)
    handler1 = FakeCallbackHandler()
    handler2 = FakeCallbackHandler()
    manager = CallbackManager([handler1, handler2])
    _test_callback_manager(manager, handler1, handler2)


def test_callback_manager_pass_in_verbose() -> None:
    """Test the CallbackManager."""
    handler1 = FakeCallbackHandler()
    handler2 = FakeCallbackHandler()
    manager = CallbackManager([handler1, handler2])
    _test_callback_manager_pass_in_verbose(manager, handler1, handler2)


def test_ignore_llm() -> None:
    """Test ignore llm param for callback handlers."""
    handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
    handler2 = FakeCallbackHandler(always_verbose_=True)
    handler1 = FakeCallbackHandler(ignore_llm_=True)
    handler2 = FakeCallbackHandler()
    manager = CallbackManager(handlers=[handler1, handler2])
    manager.on_llm_start({}, [], verbose=True)
    manager.on_llm_end(LLMResult(generations=[]), verbose=True)
    manager.on_llm_error(Exception(), verbose=True)
    run_manager = manager.on_llm_start({}, [])
    run_manager.on_llm_end(LLMResult(generations=[]))
    run_manager.on_llm_error(Exception())
    assert handler1.starts == 0
    assert handler1.ends == 0
    assert handler1.errors == 0
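The `ignore_llm` test above expects the manager to skip a handler's LLM events while still delivering everything else. A sketch of a handler that opts out the same way, assuming the `ignore_llm` property on `BaseCallbackHandler` (the handler class and its print statement are illustrative):

```python
from langchain.callbacks.base import BaseCallbackHandler


class ChainOnlyHandler(BaseCallbackHandler):
    """Illustrative handler; only the ignore_llm override matters here."""

    @property
    def ignore_llm(self) -> bool:
        # on_llm_start / on_llm_end / on_llm_error are never dispatched to us
        return True

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        print("chain started with", inputs)
```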
@@ -117,12 +105,12 @@ def test_ignore_llm() -> None:

def test_ignore_chain() -> None:
    """Test ignore chain param for callback handlers."""
    handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
    handler2 = FakeCallbackHandler(always_verbose_=True)
    handler1 = FakeCallbackHandler(ignore_chain_=True)
    handler2 = FakeCallbackHandler()
    manager = CallbackManager(handlers=[handler1, handler2])
    manager.on_chain_start({"name": "foo"}, {}, verbose=True)
    manager.on_chain_end({}, verbose=True)
    manager.on_chain_error(Exception(), verbose=True)
    run_manager = manager.on_chain_start({"name": "foo"}, {})
    run_manager.on_chain_end({})
    run_manager.on_chain_error(Exception())
    assert handler1.starts == 0
    assert handler1.ends == 0
    assert handler1.errors == 0
@@ -133,39 +121,24 @@ def test_ignore_chain() -> None:

def test_ignore_agent() -> None:
    """Test ignore agent param for callback handlers."""
    handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
    handler2 = FakeCallbackHandler(always_verbose_=True)
    handler1 = FakeCallbackHandler(ignore_agent_=True)
    handler2 = FakeCallbackHandler()
    manager = CallbackManager(handlers=[handler1, handler2])
    manager.on_tool_start({}, "", verbose=True)
    manager.on_tool_end("", verbose=True)
    manager.on_tool_error(Exception(), verbose=True)
    manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
    run_manager = manager.on_tool_start({}, "")
    run_manager.on_tool_end("")
    run_manager.on_tool_error(Exception())
    assert handler1.starts == 0
    assert handler1.ends == 0
    assert handler1.errors == 0
    assert handler2.starts == 1
    assert handler2.ends == 2
    assert handler2.ends == 1
    assert handler2.errors == 1


def test_shared_callback_manager() -> None:
    """Test the SharedCallbackManager."""
    manager1 = SharedCallbackManager()
    manager2 = SharedCallbackManager()

    assert manager1 is manager2

    handler1 = FakeCallbackHandler(always_verbose_=True)
    handler2 = FakeCallbackHandler()
    manager1.add_handler(handler1)
    manager2.add_handler(handler2)
    _test_callback_manager(manager1, handler1, handler2)


@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
    """Test the AsyncCallbackManager."""
    handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
    handler1 = FakeAsyncCallbackHandler()
    handler2 = FakeAsyncCallbackHandler()
    manager = AsyncCallbackManager([handler1, handler2])
    await _test_callback_manager_async(manager, handler1, handler2)
@@ -174,8 +147,95 @@ async def test_async_callback_manager() -> None:
@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
    """Test the AsyncCallbackManager."""
    handler1 = FakeCallbackHandler(always_verbose_=True)
    handler1 = FakeCallbackHandler()
    handler2 = FakeAsyncCallbackHandler()
    handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
    handler3 = FakeAsyncCallbackHandler()
    manager = AsyncCallbackManager([handler1, handler2, handler3])
    await _test_callback_manager_async(manager, handler1, handler2, handler3)


def test_callback_manager_inheritance() -> None:
    handler1, handler2, handler3, handler4 = (
        FakeCallbackHandler(),
        FakeCallbackHandler(),
        FakeCallbackHandler(),
        FakeCallbackHandler(),
    )

    callback_manager1 = CallbackManager([handler1, handler2])
    assert callback_manager1.handlers == [handler1, handler2]
    assert callback_manager1.inheritable_handlers == []

    callback_manager2 = CallbackManager([])
    assert callback_manager2.handlers == []
    assert callback_manager2.inheritable_handlers == []

    callback_manager2.set_handlers([handler1, handler2])
    assert callback_manager2.handlers == [handler1, handler2]
    assert callback_manager2.inheritable_handlers == [handler1, handler2]

    callback_manager2.set_handlers([handler3, handler4], inherit=False)
    assert callback_manager2.handlers == [handler3, handler4]
    assert callback_manager2.inheritable_handlers == []

    callback_manager2.add_handler(handler1)
    assert callback_manager2.handlers == [handler3, handler4, handler1]
    assert callback_manager2.inheritable_handlers == [handler1]

    callback_manager2.add_handler(handler2, inherit=False)
    assert callback_manager2.handlers == [handler3, handler4, handler1, handler2]
    assert callback_manager2.inheritable_handlers == [handler1]

    run_manager = callback_manager2.on_chain_start({"name": "foo"}, {})
    child_manager = run_manager.get_child()
    assert child_manager.handlers == [handler1]
    assert child_manager.inheritable_handlers == [handler1]

    child_manager = child_manager.on_tool_start({}, "")
    assert child_manager.handlers == [handler1]
    assert child_manager.inheritable_handlers == [handler1]

    child_manager = child_manager.get_child()
    assert child_manager.handlers == [handler1]
    assert child_manager.inheritable_handlers == [handler1]
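The inheritance test above shows the propagation rule: handlers added with `inherit=True` (the default) follow a run into its child tool and LLM calls via `get_child()`, while non-inherited handlers only see events on the manager they were attached to. A minimal sketch using only calls exercised by the test (handler variable names are illustrative):

```python
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

root_handler = StdOutCallbackHandler()   # inherited by child run managers
local_handler = StdOutCallbackHandler()  # visible to this manager only

manager = CallbackManager([root_handler])
manager.add_handler(local_handler, inherit=False)

chain_run = manager.on_chain_start({"name": "parent_chain"}, {})
child_manager = chain_run.get_child()    # carries only root_handler forward
tool_run = child_manager.on_tool_start({}, "tool input")
tool_run.on_tool_end("tool output")
chain_run.on_chain_end({})
```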
def test_callback_manager_configure() -> None:
    """Test callback manager configuration."""
    handler1, handler2, handler3, handler4 = (
        FakeCallbackHandler(),
        FakeCallbackHandler(),
        FakeCallbackHandler(),
        FakeCallbackHandler(),
    )

    inheritable_callbacks = [handler1, handler2]
    local_callbacks = [handler3, handler4]
    configured_manager = CallbackManager.configure(
        inheritable_callbacks=inheritable_callbacks,
        local_callbacks=local_callbacks,
        verbose=True,
    )

    assert len(configured_manager.handlers) == 5
    assert len(configured_manager.inheritable_handlers) == 2
    assert configured_manager.inheritable_handlers == inheritable_callbacks
    assert configured_manager.handlers[:4] == inheritable_callbacks + local_callbacks
    assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
    assert isinstance(configured_manager, CallbackManager)

    async_local_callbacks = AsyncCallbackManager([handler3, handler4])
    async_configured_manager = AsyncCallbackManager.configure(
        inheritable_callbacks=inheritable_callbacks,
        local_callbacks=async_local_callbacks,
        verbose=False,
    )

    assert len(async_configured_manager.handlers) == 4
    assert len(async_configured_manager.inheritable_handlers) == 2
    assert async_configured_manager.inheritable_handlers == inheritable_callbacks
    assert async_configured_manager.handlers == inheritable_callbacks + [
        handler3,
        handler4,
    ]
    assert isinstance(async_configured_manager, AsyncCallbackManager)
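As the configure test shows, `CallbackManager.configure` merges the inheritable and local handler lists into one manager and, with `verbose=True`, also wires in a `StdOutCallbackHandler` for console output. A sketch of how calling code might use it; the `MetricsHandler` class is a hypothetical stand-in for a real logging or metrics handler:

```python
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager


class MetricsHandler(BaseCallbackHandler):
    """Hypothetical handler standing in for real logging/metrics."""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        print("LLM starting with", len(prompts), "prompt(s)")


manager = CallbackManager.configure(
    inheritable_callbacks=[MetricsHandler()],  # passed down to child runs
    local_callbacks=[MetricsHandler()],        # used at this level only
    verbose=True,                              # also appends stdout output
)
```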
@@ -1,9 +1,9 @@
"""Test Tracer classes."""
from __future__ import annotations

import threading
from datetime import datetime
from typing import List, Optional, Union
from uuid import uuid4

import pytest
from freezegun import freeze_time
@@ -12,9 +12,7 @@ from langchain.callbacks.tracers.base import (
    BaseTracer,
    ChainRun,
    LLMRun,
    SharedTracer,
    ToolRun,
    Tracer,
    TracerException,
    TracerSession,
)
@@ -27,37 +25,45 @@ TEST_SESSION_ID = 2023
@freeze_time("2023-01-01")
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
    return ChainRun(
        id=None,
        uuid="chain_uuid",
        error=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        child_runs=[
        child_chain_runs=[],
        child_tool_runs=[
            ToolRun(
                id=None,
                uuid="tool_uuid",
                parent_uuid="chain_uuid",
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=3,
                serialized={},
                tool_input="test",
                output="test",
                action="{}",
                session_id=TEST_SESSION_ID,
                error=None,
                child_runs=[
                child_chain_runs=[],
                child_tool_runs=[],
                child_llm_runs=[
                    LLMRun(
                        id=None,
                        uuid="llm_uuid1",
                        parent_uuid="tool_uuid",
                        error=None,
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=3,
                        child_execution_order=3,
                        serialized={},
                        prompts=[],
                        response=LLMResult(generations=[[]]),
@@ -65,13 +71,17 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
                    )
                ],
            ),
        ],
        child_llm_runs=[
            LLMRun(
                id=None,
                uuid="llm_uuid2",
                parent_uuid="chain_uuid",
                error=None,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=4,
                serialized={},
                prompts=[],
                response=LLMResult(generations=[[]]),
@@ -83,27 +93,25 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:

def _perform_nested_run(tracer: BaseTracer) -> None:
    """Perform a nested run."""
    tracer.on_chain_start(serialized={}, inputs={})
    tracer.on_tool_start(serialized={}, input_str="test")
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult(generations=[[]]))
    tracer.on_tool_end("test")
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult(generations=[[]]))
    tracer.on_chain_end(outputs={})
    chain_uuid = "chain_uuid"
    tool_uuid = "tool_uuid"
    llm_uuid1 = "llm_uuid1"
    llm_uuid2 = "llm_uuid2"


def _add_child_run(
    parent_run: Union[ChainRun, ToolRun],
    child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
    """Add child run to a chain run or tool run."""
    parent_run.child_runs.append(child_run)


def _generate_id() -> Optional[Union[int, str]]:
    """Generate an id for a run."""
    return None
    tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
    tracer.on_tool_start(
        serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
    )
    tracer.on_llm_start(
        serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid
    )
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
    tracer.on_tool_end("test", run_id=tool_uuid)
    tracer.on_llm_start(
        serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
    )
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
    tracer.on_chain_end(outputs={}, run_id=chain_uuid)
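The rewritten `_perform_nested_run` threads explicit `run_id` / `parent_run_id` values through every tracer callback instead of relying on the tracer's internal stack; the fixed strings only exist so the test can compare against `_get_compare_run`. In real use the ids would be fresh UUIDs, along the lines of this sketch (assuming any `BaseTracer` implementation such as the `FakeTracer` defined further down, with a session already started):

```python
from uuid import uuid4

from langchain.schema import LLMResult

tracer = FakeTracer()  # any BaseTracer implementation behaves the same way
tracer.new_session()

chain_id, llm_id = str(uuid4()), str(uuid4())
tracer.on_chain_start(serialized={}, inputs={"question": "hi"}, run_id=chain_id)
tracer.on_llm_start(
    serialized={}, prompts=["hi"], run_id=llm_id, parent_run_id=chain_id
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_id)
tracer.on_chain_end(outputs={"answer": "hello"}, run_id=chain_id)
```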
def load_session(session_name: str) -> TracerSession:
@@ -121,7 +129,7 @@ def load_default_session() -> TracerSession:
    return TracerSession(id=1, name="default", start_time=datetime.utcnow())


class FakeTracer(Tracer):
class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution."""

    def __init__(self) -> None:
@@ -133,58 +141,6 @@ class FakeTracer(Tracer):
        """Persist a run."""
        self.runs.append(run)

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session."""
        return _persist_session(session)

    def load_session(self, session_name: str) -> TracerSession:
        """Load a tracing session."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Load a tracing session."""
        return load_default_session()


class FakeSharedTracer(SharedTracer):
    """Fake shared tracer that records LangChain execution."""

    runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run."""
        with self._lock:
            self.runs.append(run)

    def remove_runs(self) -> None:
        """Remove all runs."""
        with self._lock:
            self.runs = []

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session."""
        return _persist_session(session)
@@ -201,12 +157,15 @@ class FakeSharedTracer(SharedTracer):
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run."""
    uuid = str(uuid4())
    compare_run = LLMRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        prompts=[],
        response=LLMResult(generations=[[]]),
@@ -216,20 +175,11 @@ def test_tracer_llm_run() -> None:
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult(generations=[[]]))
    tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_session() -> None:
    """Test tracer on an LLM run without a session."""
    tracer = FakeTracer()

    with pytest.raises(TracerException):
        tracer.on_llm_start(serialized={}, prompts=[])


@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start."""
@@ -237,18 +187,21 @@ def test_tracer_llm_run_errors_no_start() -> None:

    tracer.new_session()
    with pytest.raises(TracerException):
        tracer.on_llm_end(response=LLMResult(generations=[[]]))
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=str(uuid4()))


@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Test the tracer with multiple runs."""
    uuid = str(uuid4())
    compare_run = LLMRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        prompts=[],
        response=LLMResult(generations=[[]]),
@@ -260,8 +213,8 @@ def test_tracer_multiple_llm_runs() -> None:
    tracer.new_session()
    num_runs = 10
    for _ in range(num_runs):
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult(generations=[[]]))
        tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)

    assert tracer.runs == [compare_run] * num_runs

@@ -269,12 +222,15 @@ def test_tracer_multiple_llm_runs() -> None:
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """Test tracer on a Chain run."""
    uuid = str(uuid4())
    compare_run = ChainRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        inputs={},
        outputs={},
@@ -284,20 +240,23 @@ def test_tracer_chain_run() -> None:
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_chain_start(serialized={}, inputs={})
    tracer.on_chain_end(outputs={})
    tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
    tracer.on_chain_end(outputs={}, run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """Test tracer on a Tool run."""
    uuid = str(uuid4())
    compare_run = ToolRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        tool_input="test",
        output="test",
@@ -308,8 +267,8 @@ def test_tracer_tool_run() -> None:
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_tool_start(serialized={}, input_str="test")
    tracer.on_tool_end("test")
    tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
    tracer.on_tool_end("test", run_id=uuid)
    assert tracer.runs == [compare_run]


@@ -318,21 +277,24 @@ def test_tracer_nested_run() -> None:
    """Test tracer on a nested run."""
    tracer = FakeTracer()
    tracer.new_session()
    _perform_nested_run(tracer)
    assert tracer.runs == [_get_compare_run()]
    [_perform_nested_run(tracer) for _ in range(10)]
    assert tracer.runs == [_get_compare_run()] * 10


@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """Test tracer on an LLM run with an error."""
    exception = Exception("test")
    uuid = str(uuid4())

    compare_run = LLMRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        prompts=[],
        response=None,
@@ -342,8 +304,8 @@ def test_tracer_llm_run_on_error() -> None:
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_error(exception)
    tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]


@@ -351,13 +313,16 @@ def test_tracer_llm_run_on_error() -> None:
def test_tracer_chain_run_on_error() -> None:
    """Test tracer on a Chain run with an error."""
    exception = Exception("test")
    uuid = str(uuid4())

    compare_run = ChainRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        inputs={},
        outputs=None,
@@ -367,8 +332,8 @@ def test_tracer_chain_run_on_error() -> None:
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_chain_start(serialized={}, inputs={})
    tracer.on_chain_error(exception)
    tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
    tracer.on_chain_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]


@@ -376,13 +341,16 @@ def test_tracer_chain_run_on_error() -> None:
def test_tracer_tool_run_on_error() -> None:
    """Test tracer on a Tool run with an error."""
    exception = Exception("test")
    uuid = str(uuid4())

    compare_run = ToolRun(
        id=None,
        uuid=uuid,
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        tool_input="test",
        output=None,
@@ -393,8 +361,8 @@ def test_tracer_tool_run_on_error() -> None:
    tracer = FakeTracer()

    tracer.new_session()
    tracer.on_tool_start(serialized={}, input_str="test")
    tracer.on_tool_error(exception)
    tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
    tracer.on_tool_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]


@@ -405,37 +373,53 @@ def test_tracer_nested_runs_on_error() -> None:

    tracer = FakeTracer()
    tracer.new_session()
    chain_uuid = "chain_uuid"
    tool_uuid = "tool_uuid"
    llm_uuid1 = "llm_uuid1"
    llm_uuid2 = "llm_uuid2"
    llm_uuid3 = "llm_uuid3"

    for _ in range(3):
        tracer.on_chain_start(serialized={}, inputs={})
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult(generations=[[]]))
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult(generations=[[]]))
        tracer.on_tool_start(serialized={}, input_str="test")
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_error(exception)
        tracer.on_tool_error(exception)
        tracer.on_chain_error(exception)
        tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
        tracer.on_llm_start(
            serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_llm_start(
            serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_tool_start(
            serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
        )
        tracer.on_llm_start(
            serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid
        )
        tracer.on_llm_error(exception, run_id=llm_uuid3)
        tracer.on_tool_error(exception, run_id=tool_uuid)
        tracer.on_chain_error(exception, run_id=chain_uuid)

    compare_run = ChainRun(
        id=None,
        uuid=chain_uuid,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={},
        session_id=TEST_SESSION_ID,
        error=repr(exception),
        inputs={},
        outputs=None,
        child_runs=[
        child_llm_runs=[
            LLMRun(
                id=None,
                uuid=llm_uuid1,
                parent_uuid=chain_uuid,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=2,
                serialized={},
                session_id=TEST_SESSION_ID,
                error=None,
@@ -443,36 +427,45 @@ def test_tracer_nested_runs_on_error() -> None:
                response=LLMResult(generations=[[]], llm_output=None),
            ),
            LLMRun(
                id=None,
                uuid=llm_uuid2,
                parent_uuid=chain_uuid,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=3,
                child_execution_order=3,
                serialized={},
                session_id=TEST_SESSION_ID,
                error=None,
                prompts=[],
                response=LLMResult(generations=[[]], llm_output=None),
            ),
        ],
        child_chain_runs=[],
        child_tool_runs=[
            ToolRun(
                id=None,
                uuid=tool_uuid,
                parent_uuid=chain_uuid,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=5,
                serialized={},
                session_id=TEST_SESSION_ID,
                error=repr(exception),
                tool_input="test",
                output=None,
                action="{}",
                child_runs=[
                child_llm_runs=[
                    LLMRun(
                        id=None,
                        uuid=llm_uuid3,
                        parent_uuid=tool_uuid,
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=5,
                        child_execution_order=5,
                        serialized={},
                        session_id=TEST_SESSION_ID,
                        error=repr(exception),
@@ -480,43 +473,10 @@ def test_tracer_nested_runs_on_error() -> None:
                        response=None,
                    )
                ],
                child_llm_runs=[],
                child_chain_runs=[],
                child_tool_runs=[],
            ),
        ],
        child_llm_runs=[],
        child_chain_runs=[],
        child_tool_runs=[],
    )

    assert tracer.runs == [compare_run] * 3


@freeze_time("2023-01-01")
def test_shared_tracer_nested_run() -> None:
    """Test shared tracer on a nested run."""
    tracer = FakeSharedTracer()
    tracer.new_session()
    tracer.remove_runs()
    _perform_nested_run(tracer)
    assert tracer.runs == [_get_compare_run()]


@freeze_time("2023-01-01")
def test_shared_tracer_nested_run_multithreaded() -> None:
    """Test shared tracer on a nested run."""
    tracer = FakeSharedTracer()
    tracer.remove_runs()
    tracer.new_session()
    threads = []
    num_threads = 10
    for _ in range(num_threads):
        thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    assert tracer.runs == [_get_compare_run()] * num_threads

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional

import pytest

from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

@@ -1,5 +1,5 @@
"""Test LLM callbacks."""
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM
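The import change in the LLM callback tests mirrors how handlers are attached after this refactor: instead of building a `CallbackManager` from `langchain.callbacks.base`, handlers are passed through the `callbacks` argument described in the updated docs. A sketch of that usage with the test doubles imported above; the exact constructor signature of `FakeLLM` and the counter attributes are assumptions based on this diff:

```python
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM

handler = FakeCallbackHandler()
llm = FakeLLM(callbacks=[handler])  # handler receives this LLM's run events
llm("hello")

# FakeCallbackHandler keeps per-event counters (see the assertions above).
print(handler.llm_starts, handler.llm_ends)
```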