Compare commits

...

12 Commits

Author       SHA1        Message                                                                               Date
Dev 2049     795f823bb1  rfc                                                                                   2023-04-27 15:53:35 -07:00
Ankush Gola  6cd653deb4  cr                                                                                    2023-04-27 14:16:31 -07:00
Ankush Gola  e953d2cf93  mypy                                                                                  2023-04-27 12:26:58 -07:00
Ankush Gola  50668693d7  fix execution order issue                                                             2023-04-26 19:19:39 -07:00
Ankush Gola  6fec15b6fb  write to different session                                                            2023-04-26 11:37:36 -07:00
Ankush Gola  7bcdc66b99  fix notebook and warnings                                                             2023-04-26 11:34:21 -07:00
Ankush Gola  4cdd19bd4e  Callbacks Refactor [2/n] update tracer to work with new callbacks mechanism (#3381)  2023-04-25 18:20:16 -07:00
Ankush Gola  90cef7b53a  cr                                                                                    2023-04-22 21:49:34 -07:00
Ankush Gola  675e27c136  Callbacks Refactor [2/n]: refactor CallbackManager code to own file (#3341)          2023-04-22 21:40:59 -07:00
Ankush Gola  fa4a4f2940  cr                                                                                    2023-04-20 17:06:00 -07:00
Ankush Gola  55c7964e4e  Merge branch 'master' into ankush/callbacks-refactor                                  2023-04-20 17:01:32 -07:00
Ankush Gola  3cc2ce6ac9  callbacks changes                                                                     2023-04-20 16:58:04 -07:00
97 changed files with 3667 additions and 2051 deletions

View File

@@ -28,13 +28,15 @@ To stream the model's predictions, add in a CallbackManager.
```python
from langchain.llms import GPT4All
-from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# There are many CallbackHandlers supported, such as
# from langchain.callbacks.streamlit import StreamlitCallbackHandler
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True)
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler,
verbose=True)
# Generate text. Tokens are streamed through the callback manager.
model("Once upon a time, ")

View File

@@ -17,33 +17,7 @@
"source": [
"LangChain provides a callback system that allows you to hook into the various stages of your LLM application. This is useful for logging, [monitoring](https://python.langchain.com/en/latest/tracing.html), [streaming](https://python.langchain.com/en/latest/modules/models/llms/examples/streaming_llm.html), and other tasks.\n",
"\n",
"You can subscribe to these events by using the `callback_manager` argument available throughout the API. A `CallbackManager` is an object that manages a list of `CallbackHandlers`. The `CallbackManager` will call the appropriate method on each handler when the event is triggered."
]
},
{
"cell_type": "markdown",
"id": "fdb72e8d-a02a-474d-96bf-f5759432afc8",
"metadata": {
"tags": []
},
"source": [
"```python\n",
"class CallbackManager(BaseCallbackHandler):\n",
" \"\"\"Base callback manager that can be used to handle callbacks from LangChain.\"\"\"\n",
"\n",
" def add_handler(self, callback: BaseCallbackHandler) -> None:\n",
" \"\"\"Add a handler to the callback manager.\"\"\"\n",
"\n",
" def remove_handler(self, handler: BaseCallbackHandler) -> None:\n",
" \"\"\"Remove a handler from the callback manager.\"\"\"\n",
"\n",
" def set_handler(self, handler: BaseCallbackHandler) -> None:\n",
" \"\"\"Set handler as the only handler on the callback manager.\"\"\"\n",
" self.set_handlers([handler])\n",
"\n",
" def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:\n",
" \"\"\"Set handlers as the only handlers on the callback manager.\"\"\"\n",
"```"
"You can subscribe to these events by using the `callbacks` argument available throughout the API. This argument list of handler objects, which are expected to implement one or more of the methods described in the API docs."
]
},
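As a quick sketch of the new `callbacks` argument described above (a minimal example of ours, assuming the post-refactor API in this diff, not code from the PR itself): a handler only needs to implement the events it cares about, and a plain list of handler instances is what gets passed around.

```python
from typing import Any, Dict

from langchain.callbacks.base import BaseCallbackHandler


class ChainStartLogger(BaseCallbackHandler):
    """Hypothetical handler that only implements the one event it cares about."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        print(f"Chain starting with inputs: {inputs}")


# Any list of handler objects can be passed wherever `callbacks=` is accepted
handlers = [ChainStartLogger()]
```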
{
@@ -62,70 +36,57 @@
},
"source": [
"```python\n",
"class BaseCallbackHandler(ABC):\n",
"class BaseCallbackHandler:\n",
" \"\"\"Base callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_llm_start(\n",
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when LLM starts running.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n",
" \"\"\"Run on new LLM token. Only available when streaming is enabled.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:\n",
" \"\"\"Run when LLM ends running.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_llm_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when LLM errors.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_chain_start(\n",
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when chain starts running.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:\n",
" \"\"\"Run when chain ends running.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_chain_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when chain errors.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_tool_start(\n",
" self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when tool starts running.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_tool_end(self, output: str, **kwargs: Any) -> Any:\n",
" \"\"\"Run when tool ends running.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_tool_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when tool errors.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_text(self, text: str, **kwargs: Any) -> Any:\n",
" \"\"\"Run on arbitrary text.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n",
" \"\"\"Run on agent action.\"\"\"\n",
"\n",
" @abstractmethod\n",
" def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:\n",
" \"\"\"Run on agent end.\"\"\"\n",
"```"
@@ -136,9 +97,11 @@
"id": "d3bf3304-43fb-47ad-ae50-0637a17018a2",
"metadata": {},
"source": [
"## Creating and Using a Custom `CallbackHandler`\n",
"## Using an existing handler\n",
"\n",
"By default, a shared CallbackManager with the StdOutCallbackHandler will be used by models, chains, agents, and tools. However, you can pass in your own CallbackManager with a custom CallbackHandler:"
"LangChain provides a few built-in handlers that you can use to get started. These are available in the `langchain/callbacks` module. The most basic handler is the `StdOutCallbackHandler`, which simply logs all events to `stdout`. In the future we will add more default handlers to the library. \n",
"\n",
"**Note** when the `verbose` flag on the object is set to true, the `StdOutCallbackHandler` will be invoked even without being explicitly passed in."
]
},
{
@@ -155,16 +118,16 @@
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"AgentAction(tool='Search', tool_input=\"US Open men's final 2019 winner\", log=' I need to find out who won the US Open men\\'s final in 2019 and then calculate his age raised to the 0.334 power.\\nAction: Search\\nAction Input: \"US Open men\\'s final 2019 winner\"')\n",
"Rafael Nadal defeated Daniil Medvedev in the final, 75, 63, 57, 46, 64 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\n",
"AgentAction(tool='Search', tool_input='Rafael Nadal age', log=' I need to find out the age of the winner\\nAction: Search\\nAction Input: \"Rafael Nadal age\"')\n",
"36 years\n",
"AgentAction(tool='Calculator', tool_input='36^0.334', log=' I now need to calculate his age raised to the 0.334 power\\nAction: Calculator\\nAction Input: 36^0.334')\n",
"Answer: 3.3098250249682484\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n",
"\n",
" I now know the final answer\n",
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -172,7 +135,7 @@
{
"data": {
"text/plain": [
"\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\""
"'\\n\\n3'"
]
},
"execution_count": 1,
@@ -181,108 +144,89 @@
}
],
"source": [
"from typing import Any, Dict, List, Optional, Union\n",
"\n",
"from langchain.agents import initialize_agent, load_tools\n",
"from langchain.agents import AgentType\n",
"from langchain.callbacks.base import CallbackManager, BaseCallbackHandler\n",
"from langchain.callbacks import StdOutCallbackHandler\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.schema import AgentAction, AgentFinish, LLMResult\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"class MyCustomCallbackHandler(BaseCallbackHandler):\n",
" \"\"\"Custom CallbackHandler.\"\"\"\n",
"handler = StdOutCallbackHandler()\n",
"llm = OpenAI()\n",
"prompt = PromptTemplate.from_template(\"1 + {number} = \")\n",
"\n",
" def on_llm_start(\n",
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Print out the prompts.\"\"\"\n",
" pass\n",
"# First, let's explicitly set the StdOutCallbackHandler in `callbacks`\n",
"chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])\n",
"chain.run(number=2)\n",
"\n",
" def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n",
" \"\"\"Do nothing.\"\"\"\n",
" pass\n",
"# Then, let's use the `verbose` flag to achieve the same result\n",
"chain = LLMChain(llm=llm, prompt=prompt, verbose=True)\n",
"chain.run(number=2)"
]
},
{
"cell_type": "markdown",
"id": "389c8448-5283-49e3-8c04-dbe1522e202c",
"metadata": {},
"source": [
"## Creating a custom handler\n",
"\n",
" def on_llm_new_token(self, token: str, **kwargs: Any) -> None:\n",
" \"\"\"Do nothing.\"\"\"\n",
" pass\n",
"You can create a custom handler to set on the object as well. In the example below, we'll implement streaming with a custom handler."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1b2e6588-0681-4cab-937a-7cc4790cea9a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"My custom handler, token: \n",
"My custom handler, token: Why\n",
"My custom handler, token: did\n",
"My custom handler, token: the\n",
"My custom handler, token: tomato\n",
"My custom handler, token: turn\n",
"My custom handler, token: red\n",
"My custom handler, token: ?\n",
"My custom handler, token: Because\n",
"My custom handler, token: it\n",
"My custom handler, token: saw\n",
"My custom handler, token: the\n",
"My custom handler, token: salad\n",
"My custom handler, token: dressing\n",
"My custom handler, token: !\n",
"My custom handler, token: \n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content='Why did the tomato turn red? Because it saw the salad dressing!', additional_kwargs={})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.callbacks.base import BaseCallbackHandler\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema import HumanMessage\n",
"\n",
" def on_llm_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Do nothing.\"\"\"\n",
" pass\n",
"class MyCustomHandler(BaseCallbackHandler):\n",
" def on_llm_new_token(self, token: str, **kwargs) -> None:\n",
" print(f\"My custom handler, token: {token}\")\n",
"\n",
" def on_chain_start(\n",
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Print out that we are entering a chain.\"\"\"\n",
" class_name = serialized[\"name\"]\n",
" print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n",
"# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n",
"# Additionally, we pass in a list with our custom handler\n",
"chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomHandler()])\n",
"\n",
" def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n",
" \"\"\"Print out that we finished a chain.\"\"\"\n",
" print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n",
"\n",
" def on_chain_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Do nothing.\"\"\"\n",
" pass\n",
"\n",
" def on_tool_start(\n",
" self,\n",
" serialized: Dict[str, Any],\n",
" input_str: str,\n",
" **kwargs: Any,\n",
" ) -> None:\n",
" \"\"\"Do nothing.\"\"\"\n",
" pass\n",
"\n",
" def on_agent_action(\n",
" self, action: AgentAction, color: Optional[str] = None, **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run on agent action.\"\"\"\n",
" print(action)\n",
"\n",
" def on_tool_end(\n",
" self,\n",
" output: str,\n",
" color: Optional[str] = None,\n",
" observation_prefix: Optional[str] = None,\n",
" llm_prefix: Optional[str] = None,\n",
" **kwargs: Any,\n",
" ) -> None:\n",
" \"\"\"If not the final action, print out observation.\"\"\"\n",
" print(output)\n",
"\n",
" def on_tool_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Do nothing.\"\"\"\n",
" pass\n",
"\n",
" def on_text(\n",
" self,\n",
" text: str,\n",
" color: Optional[str] = None,\n",
" end: str = \"\",\n",
" **kwargs: Optional[str],\n",
" ) -> None:\n",
" \"\"\"Run when agent ends.\"\"\"\n",
" print(text)\n",
"\n",
" def on_agent_finish(\n",
" self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Run on agent end.\"\"\"\n",
" print(finish.log)\n",
"manager = CallbackManager([MyCustomCallbackHandler()])\n",
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)\n",
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, callback_manager=manager)\n",
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager\n",
")\n",
"agent.run(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")"
"chat([HumanMessage(content=\"Tell me a joke\")])"
]
},
{
@@ -292,9 +236,11 @@
"tags": []
},
"source": [
"## Async Support\n",
"## Async Callbacks\n",
"\n",
"If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` and `AsyncCallbackManager` to avoid blocking the runloop."
"If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` to avoid blocking the runloop. \n",
"\n",
"**Advanced** if you use a sync `CallbackHandler` while using an async method to run your llm/chain/tool/agent, it will still work. However, under the hood, it will be called with [`run_in_executor`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) which can cause issues if your `CallbackHandler` is not thread-safe."
]
},
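Because a sync handler used with async methods runs via `run_in_executor` on worker threads, any shared state it touches should be guarded. A minimal sketch of a thread-safe sync handler (our hypothetical example, not from this PR):

```python
import threading
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler


class ThreadSafeTokenCounter(BaseCallbackHandler):
    """Sync handler that stays correct when invoked from executor threads."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.token_count = 0

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # May run on a thread-pool worker when the async API is used,
        # so mutate shared state only under a lock.
        with self._lock:
            self.token_count += 1
```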
{
@@ -310,58 +256,589 @@
"output_type": "stream",
"text": [
"zzzz....\n",
"Hi! I just woke up. Your llm is starting\n",
"Sync handler being called in a `thread_pool_executor`: token: \n",
"Sync handler being called in a `thread_pool_executor`: token: Why\n",
"Sync handler being called in a `thread_pool_executor`: token: don\n",
"Sync handler being called in a `thread_pool_executor`: token: 't\n",
"Sync handler being called in a `thread_pool_executor`: token: scientists\n",
"Sync handler being called in a `thread_pool_executor`: token: trust\n",
"Sync handler being called in a `thread_pool_executor`: token: atoms\n",
"Sync handler being called in a `thread_pool_executor`: token: ?\n",
"\n",
"\n",
"Sync handler being called in a `thread_pool_executor`: token: Because\n",
"Sync handler being called in a `thread_pool_executor`: token: they\n",
"Sync handler being called in a `thread_pool_executor`: token: make\n",
"Sync handler being called in a `thread_pool_executor`: token: up\n",
"Sync handler being called in a `thread_pool_executor`: token: everything\n",
"Sync handler being called in a `thread_pool_executor`: token: !\n",
"Sync handler being called in a `thread_pool_executor`: token: \n",
"zzzz....\n",
"Hi! I just woke up. Your llm is ending\n"
]
},
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGeneration(text=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", generation_info=None, message=AIMessage(content=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", additional_kwargs={}))]], llm_output={'token_usage': {}, 'model_name': 'gpt-3.5-turbo'})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import asyncio\n",
"from typing import Any, Dict, List\n",
"from langchain.schema import LLMResult\n",
"from langchain.callbacks.base import AsyncCallbackHandler\n",
"\n",
"class MyCustomSyncHandler(BaseCallbackHandler):\n",
" def on_llm_new_token(self, token: str, **kwargs) -> None:\n",
" print(f\"Sync handler being called in a `thread_pool_executor`: token: {token}\")\n",
"\n",
"class MyCustomAsyncHandler(AsyncCallbackHandler):\n",
" \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
"\n",
" async def on_llm_start(\n",
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Run when chain starts running.\"\"\"\n",
" print(\"zzzz....\")\n",
" await asyncio.sleep(0.3)\n",
" class_name = serialized[\"name\"]\n",
" print(\"Hi! I just woke up. Your llm is starting\")\n",
"\n",
" async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n",
" \"\"\"Run when chain ends running.\"\"\"\n",
" print(\"zzzz....\")\n",
" await asyncio.sleep(0.3)\n",
" print(\"Hi! I just woke up. Your llm is ending\")\n",
"\n",
"# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n",
"# Additionally, we pass in a list with our custom handler\n",
"chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()])\n",
"\n",
"await chat.agenerate([[HumanMessage(content=\"Tell me a joke\")]])"
]
},
{
"cell_type": "markdown",
"id": "d26dbb34-fcc3-401c-a115-39c7620d2d65",
"metadata": {},
"source": [
"## Using multiple handlers, passing in handlers\n",
"\n",
"In the previous examples, we passed in callback handlers upon creation of an object by using `callbacks=`. In this case, the callbacks will be scoped to that particular object. \n",
"\n",
"However, in many cases, it is advantageous to pass in handlers instead when running the object. When we pass through `CallbackHandlers` using the `callbacks` keyword arg when executing an run, those callbacks will be issued by all nested objects involved in the execution. For example, when a handler is passed through to an `Agent`, it will be used for all callbacks related to the agent and all the objects involved in the agent's execution, in this case, the `Tools`, `LLMChain`, and `LLM`.\n",
"\n",
"This prevents us from having to manually attach the handlers to each individual nested object."
]
},
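The distinction in a nutshell, before the full example below (a sketch reusing the `llm`, `prompt`, and `handler` names from the earlier cells; the run-scoped form follows this PR's changes to `run`):

```python
# Constructor-scoped: this handler only sees events from `chain` itself.
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
chain.run(number=2)

# Run-scoped: passed at execution time, the handler is inherited by every
# nested object involved in this particular run.
chain = LLMChain(llm=llm, prompt=prompt)
chain.run(number=2, callbacks=[handler])
```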
{
"cell_type": "code",
"execution_count": 4,
"id": "8eec8756-1828-45cb-9699-38ac8543a150",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"on_chain_start AgentExecutor\n",
"on_chain_start LLMChain\n",
"on_llm_start OpenAI\n",
"on_llm_start (I'm the second handler!!) OpenAI\n",
"on_new_token I\n",
"on_new_token need\n",
"on_new_token to\n",
"on_new_token use\n",
"on_new_token a\n",
"on_new_token calculator\n",
"on_new_token to\n",
"on_new_token solve\n",
"on_new_token this\n",
"on_new_token .\n",
"on_new_token \n",
"Action\n",
"on_new_token :\n",
"on_new_token Calculator\n",
"on_new_token \n",
"Action\n",
"on_new_token Input\n",
"on_new_token :\n",
"on_new_token 2\n",
"on_new_token ^\n",
"on_new_token 0\n",
"on_new_token .\n",
"on_new_token 235\n",
"on_new_token \n",
"on_agent_action AgentAction(tool='Calculator', tool_input='2^0.235', log=' I need to use a calculator to solve this.\\nAction: Calculator\\nAction Input: 2^0.235')\n",
"on_tool_start Calculator\n",
"on_chain_start LLMMathChain\n",
"on_chain_start LLMChain\n",
"on_llm_start OpenAI\n",
"on_llm_start (I'm the second handler!!) OpenAI\n",
"on_new_token \n",
"\n",
"on_new_token ```text\n",
"on_new_token \n",
"\n",
"on_new_token 2\n",
"on_new_token **\n",
"on_new_token 0\n",
"on_new_token .\n",
"on_new_token 235\n",
"on_new_token \n",
"\n",
"on_new_token ```\n",
"\n",
"on_new_token ...\n",
"on_new_token num\n",
"on_new_token expr\n",
"on_new_token .\n",
"on_new_token evaluate\n",
"on_new_token (\"\n",
"on_new_token 2\n",
"on_new_token **\n",
"on_new_token 0\n",
"on_new_token .\n",
"on_new_token 235\n",
"on_new_token \")\n",
"on_new_token ...\n",
"on_new_token \n",
"\n",
"on_new_token \n",
"on_chain_start LLMChain\n",
"on_llm_start OpenAI\n",
"on_llm_start (I'm the second handler!!) OpenAI\n",
"on_new_token I\n",
"on_new_token now\n",
"on_new_token know\n",
"on_new_token the\n",
"on_new_token final\n",
"on_new_token answer\n",
"on_new_token .\n",
"on_new_token \n",
"Final\n",
"on_new_token Answer\n",
"on_new_token :\n",
"on_new_token 1\n",
"on_new_token .\n",
"on_new_token 17\n",
"on_new_token 690\n",
"on_new_token 67\n",
"on_new_token 372\n",
"on_new_token 187\n",
"on_new_token 674\n",
"on_new_token \n"
]
},
{
"data": {
"text/plain": [
"'1.1769067372187674'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Dict, Union, Any, List\n",
"\n",
"from langchain.callbacks.base import BaseCallbackHandler\n",
"from langchain.schema import AgentAction\n",
"from langchain.agents import AgentType, initialize_agent, load_tools\n",
"from langchain.callbacks import tracing_enabled\n",
"from langchain.llms import OpenAI\n",
"\n",
"# First, define custom callback handler implementations\n",
"class MyCustomHandlerOne(BaseCallbackHandler):\n",
" def on_llm_start(\n",
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
" ) -> Any:\n",
" print(f\"on_llm_start {serialized['name']}\")\n",
"\n",
" def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n",
" print(f\"on_new_token {token}\")\n",
"\n",
" def on_llm_error(\n",
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
" ) -> Any:\n",
" \"\"\"Run when LLM errors.\"\"\"\n",
"\n",
" def on_chain_start(\n",
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
" ) -> Any:\n",
" print(f\"on_chain_start {serialized['name']}\")\n",
"\n",
" def on_tool_start(\n",
" self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n",
" ) -> Any:\n",
" print(f\"on_tool_start {serialized['name']}\")\n",
"\n",
" def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n",
" print(f\"on_agent_action {action}\")\n",
"\n",
"class MyCustomHandlerTwo(BaseCallbackHandler):\n",
" def on_llm_start(\n",
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
" ) -> Any:\n",
" print(f\"on_llm_start (I'm the second handler!!) {serialized['name']}\")\n",
"\n",
"# Instantiate the handlers\n",
"handler1 = MyCustomHandlerOne()\n",
"handler2 = MyCustomHandlerTwo()\n",
"\n",
"# Setup the agent. Only the `llm` will issue callbacks for handler2\n",
"llm = OpenAI(temperature=0, streaming=True, callbacks=[handler2])\n",
"tools = load_tools([\"llm-math\"], llm=llm)\n",
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n",
")\n",
"\n",
"# Callbacks for handler1 will be issued by every object involved in the \n",
"# Agent execution (llm, llmchain, tool, agent executor)\n",
"agent.run(\"What is 2 raised to the 0.235 power?\", callbacks=[handler1])"
]
},
{
"cell_type": "markdown",
"id": "32b29135-f852-4492-88ed-547275c72c53",
"metadata": {},
"source": [
"# Tracing and Token Counting"
]
},
{
"cell_type": "markdown",
"id": "fbb606d6-2863-46c5-8347-9f0bdb3805bb",
"metadata": {},
"source": [
"Tracing and token counting are two capabilities we provide which are built on our callbacks mechanism."
]
},
{
"cell_type": "markdown",
"id": "f62cd10c-494c-47d6-aa98-6e926cb9c456",
"metadata": {},
"source": [
"## Tracing"
]
},
{
"cell_type": "markdown",
"id": "d5a74b3f-3769-4a4f-99c7-b6a3b20a94e2",
"metadata": {},
"source": [
"There are two recommended ways to trace your LangChains:\n",
"\n",
"1. Setting the `LANGCHAIN_TRACING` environment variable to `\"true\"`. \n",
"2. Using a context manager `with tracing_enabled()` to trace a particular block of code.\n",
"\n",
"**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f164dfd5-d987-4b6a-a7c8-019c651ce47f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain.agents import AgentType, initialize_agent, load_tools\n",
"from langchain.callbacks import tracing_enabled\n",
"from langchain.llms import OpenAI\n",
"\n",
"# To run the code, make sure to set OPENAI_API_KEY and SERPAPI_API_KEY\n",
"llm = OpenAI(temperature=0)\n",
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n",
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
")\n",
"\n",
"questions = [\n",
" \"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\",\n",
" \"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\",\n",
" \"Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?\",\n",
" \"Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?\",\n",
" \"Who is Beyonce's husband? What is his age raised to the 0.19 power?\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6be7777e-ec1d-438f-ae33-3a93c45f808e",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"zzzz....\n",
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
"Action: Search\n",
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 75, 63, 57, 46, 64 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
"Action: Calculator\n",
"Action Input: 36^0.334\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
"Action: Search\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
"Action: Calculator\n",
"Action Input: 29^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
}
],
"source": [
"import asyncio\n",
"from aiohttp import ClientSession\n",
"os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n",
"\n",
"from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager\n",
"\n",
"class MyCustomAsyncCallbackHandler(AsyncCallbackHandler):\n",
" \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
"\n",
" async def on_chain_start(\n",
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Run when chain starts running.\"\"\"\n",
" print(\"zzzz....\")\n",
" await asyncio.sleep(0.5)\n",
" class_name = serialized[\"name\"]\n",
" print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n",
"\n",
" async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n",
" \"\"\"Run when chain ends running.\"\"\"\n",
" print(\"zzzz....\")\n",
" await asyncio.sleep(0.5)\n",
" print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n",
"\n",
"manager = AsyncCallbackManager([MyCustomAsyncCallbackHandler()])\n",
"\n",
"# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
"# but you must manually close the client session at the end of your program/event loop\n",
"aiosession = ClientSession()\n",
"llm = OpenAI(temperature=0, callback_manager=manager)\n",
"async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n",
"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
"await async_agent.arun(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")\n",
"await aiosession.close()"
"# Both of the agent runs will be traced because the environment variable is set\n",
"agent.run(questions[0])\n",
"with tracing_enabled() as session:\n",
" assert session\n",
" agent.run(questions[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86be6304-e433-4048-880c-a92a73244407",
"execution_count": 10,
"id": "a6fd6026-dc1e-4d48-893d-3592539c7828",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
"Action: Search\n",
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 75, 63, 57, 46, 64 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
"Action: Calculator\n",
"Action Input: 36^0.334\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
"Action: Search\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
"Action: Calculator\n",
"Action Input: 29^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now, we unset the environment variable and use a context manager.\n",
"\n",
"if \"LANGCHAIN_TRACING\" in os.environ:\n",
" del os.environ[\"LANGCHAIN_TRACING\"]\n",
"\n",
"# here, we are writing traces to \"my_test_session\"\n",
"with tracing_enabled(\"my_test_session\") as session:\n",
" assert session\n",
" agent.run(questions[0]) # this should be traced\n",
"\n",
"agent.run(questions[1]) # this should not be traced"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9383a351-4983-44e9-abd7-ef942e1c65c4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
"Action: Search\n",
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 75, 63, 57, 46, 64 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\u001b[33;1m\u001b[1;3mLewis Hamilton has won 103 Grands Prix during his career. He won 21 races with McLaren and has won 82 with Mercedes. Lewis Hamilton holds the record for the ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[33;1m\u001b[1;3m36 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
"Action: Search\n",
"Action Input: \"Harry Styles age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Lewis Hamilton's age\n",
"Action: Search\n",
"Action Input: \"Lewis Hamilton Age\"\u001b[0m\u001b[33;1m\u001b[1;3m29 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
"Action: Calculator\n",
"Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
"Action: Calculator\n",
"Action Input: 29^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\u001b[33;1m\u001b[1;3m38 years\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I now need to calculate 38 raised to the 0.23 power\n",
"Action: Calculator\n",
"Action Input: 38^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.3086081644669734\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The context manager is concurrency safe:\n",
"if \"LANGCHAIN_TRACING\" in os.environ:\n",
" del os.environ[\"LANGCHAIN_TRACING\"]\n",
"\n",
"# start a background task\n",
"task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n",
"with tracing_enabled() as session:\n",
" assert session\n",
" tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n",
" await asyncio.gather(*tasks)\n",
"\n",
"await task"
]
},
{
"cell_type": "markdown",
"id": "254fef1b-6b6e-4352-9cf4-363fba895ac7",
"metadata": {},
"source": [
"## Token Counting\n",
"LangChain offers a context manager that allows you to count tokens."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5c3e0b89-2c5e-4036-bdf2-fb6b750e360c",
"metadata": {
"tags": []
},
"outputs": [],
"source": []
"source": [
"from langchain.callbacks import get_openai_callback\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"with get_openai_callback() as cb:\n",
" llm(\"What is the square root of 4?\")\n",
"\n",
"total_tokens = cb.total_tokens\n",
"assert total_tokens > 0\n",
"\n",
"with get_openai_callback() as cb:\n",
" llm(\"What is the square root of 4?\")\n",
" llm(\"What is the square root of 4?\")\n",
"\n",
"assert cb.total_tokens == total_tokens * 2\n",
"\n",
"# You can kick off concurrent runs from within the context manager\n",
"with get_openai_callback() as cb:\n",
" await asyncio.gather(\n",
" *[llm.agenerate([\"What is the square root of 4?\"]) for _ in range(3)]\n",
" )\n",
"\n",
"assert cb.total_tokens == total_tokens * 3\n",
"\n",
"# The context manager is concurrency safe\n",
"task = asyncio.create_task(llm.agenerate([\"What is the square root of 4?\"]))\n",
"with get_openai_callback() as cb:\n",
" await llm.agenerate([\"What is the square root of 4?\"])\n",
"\n",
"await task\n",
"assert cb.total_tokens == total_tokens"
]
}
],
"metadata": {

View File

@@ -5,11 +5,6 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks import (
set_default_callback_manager,
set_handler,
set_tracing_callback_manager,
)
from langchain.chains import (
ConversationChain,
LLMBashChain,
@@ -65,7 +60,6 @@ del metadata # optional, avoids polluting the results of dir(__package__)
verbose: bool = False
llm_cache: Optional[BaseCache] = None
set_default_callback_manager()
# For backwards compatibility
SerpAPIChain = SerpAPIWrapper
@@ -115,7 +109,5 @@ __all__ = [
"VectorDBQAWithSourcesChain",
"QAWithSourcesChain",
"PALChain",
"set_handler",
"set_tracing_callback_manager",
"LlamaCpp",
]

View File

@@ -13,7 +13,13 @@ import yaml
from pydantic import BaseModel, root_validator
from langchain.agents.tools import InvalidTool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
@@ -23,7 +29,6 @@ from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
)
@@ -46,13 +51,17 @@ class BaseSingleActionAgent(BaseModel):
@abstractmethod
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@@ -61,13 +70,17 @@ class BaseSingleActionAgent(BaseModel):
@abstractmethod
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@@ -170,13 +183,17 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@@ -185,13 +202,17 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@@ -285,38 +306,52 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
output = self.llm_chain.run(
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
intermediate_steps=intermediate_steps,
stop=self.stop,
callbacks=callbacks,
**kwargs,
)
return self.output_parser.parse(output)
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
output = await self.llm_chain.arun(
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
intermediate_steps=intermediate_steps,
stop=self.stop,
callbacks=callbacks,
**kwargs,
)
return self.output_parser.parse(output)
@@ -368,37 +403,45 @@ class Agent(BaseSingleActionAgent):
return thoughts
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(**full_inputs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(**full_inputs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
def get_full_inputs(
@@ -632,24 +675,27 @@ class AgentExecutor(Chain):
return True
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
self.callback_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
def _return(
self,
output: AgentFinish,
intermediate_steps: list,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if run_manager:
run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
async def _areturn(
self, output: AgentFinish, intermediate_steps: list
self,
output: AgentFinish,
intermediate_steps: list,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if self.callback_manager.is_async:
await self.callback_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
else:
self.callback_manager.on_agent_finish(
if run_manager:
await run_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
final_output = output.return_values
@@ -663,13 +709,18 @@ class AgentExecutor(Chain):
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
# Call the LLM to see what to do.
output = self.agent.plan(intermediate_steps, **inputs)
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
@@ -680,9 +731,8 @@ class AgentExecutor(Chain):
actions = output
result = []
for agent_action in actions:
self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
@@ -696,6 +746,7 @@ class AgentExecutor(Chain):
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
@@ -704,6 +755,7 @@ class AgentExecutor(Chain):
agent_action.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
result.append((agent_action, observation))
@@ -715,13 +767,18 @@ class AgentExecutor(Chain):
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
# Call the LLM to see what to do.
output = await self.agent.aplan(intermediate_steps, **inputs)
output = await self.agent.aplan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
@@ -734,12 +791,8 @@ class AgentExecutor(Chain):
async def _aperform_agent_action(
agent_action: AgentAction,
) -> Tuple[AgentAction, str]:
if self.callback_manager.is_async:
await self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
else:
self.callback_manager.on_agent_action(
if run_manager:
await run_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
@@ -755,6 +808,7 @@ class AgentExecutor(Chain):
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
@@ -763,6 +817,7 @@ class AgentExecutor(Chain):
agent_action.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return agent_action, observation
@@ -774,7 +829,11 @@ class AgentExecutor(Chain):
return list(result)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -790,10 +849,16 @@ class AgentExecutor(Chain):
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations, time_elapsed):
next_step_output = self._take_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return self._return(next_step_output, intermediate_steps)
return self._return(
next_step_output, intermediate_steps, run_manager=run_manager
)
intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:
@@ -809,7 +874,11 @@ class AgentExecutor(Chain):
)
return self._return(output, intermediate_steps)
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -827,7 +896,11 @@ class AgentExecutor(Chain):
try:
while self._should_continue(iterations, time_elapsed):
next_step_output = await self._atake_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)
@@ -845,7 +918,9 @@ class AgentExecutor(Chain):
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
return await self._areturn(
output, intermediate_steps, run_manager=run_manager
)
except TimeoutError:
# stop early when interrupted by the async timeout
output = self.agent.return_stopped_response(

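The pattern running through this file: a component receives a run manager for its own run and hands `run_manager.get_child()` to everything it invokes, so run-scoped handlers reach all nested objects. A hedged sketch of the same pattern in a custom chain (the class and field names are hypothetical; it assumes the post-refactor `Chain` interface shown in this PR):

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class OutlineThenDraft(Chain):
    """Hypothetical two-step chain demonstrating callback propagation."""

    outline_chain: Chain
    draft_chain: Chain

    @property
    def input_keys(self) -> List[str]:
        return ["topic"]

    @property
    def output_keys(self) -> List[str]:
        return ["draft"]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Hand child callbacks to every nested component, as AgentExecutor does.
        child_callbacks = run_manager.get_child() if run_manager else None
        outline = self.outline_chain.run(inputs["topic"], callbacks=child_callbacks)
        draft = self.draft_chain.run(outline, callbacks=child_callbacks)
        return {"draft": draft}
```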
View File

@@ -26,12 +26,12 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.requests.tool import BaseRequestsTool

View File

@@ -5,6 +5,7 @@ from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
@@ -13,7 +14,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseLanguageModel
from langchain.schema import AgentAction
from langchain.tools import BaseTool

View File

@@ -9,10 +9,10 @@ from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
SUFFIX,
TEMPLATE_TOOL_RESPONSE,
)
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.base import BasePromptTemplate
@@ -24,7 +25,6 @@ from langchain.prompts.chat import (
from langchain.schema import (
AgentAction,
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,

View File

@@ -4,8 +4,8 @@ from typing import Any, Optional, Sequence
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
return Tool(
name="Calculator",
description="Useful for when you need to answer questions about math.",
func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run,
coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun,
func=LLMMathChain(llm=llm).run,
coroutine=LLMMathChain(llm=llm).arun,
)

View File

@@ -10,10 +10,10 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@@ -4,6 +4,11 @@ from typing import Any, Awaitable, Callable, Optional, Type, Union
from pydantic import BaseModel, validate_arguments
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
Callbacks,
)
from langchain.tools.base import BaseTool
@@ -26,14 +31,42 @@ class Tool(BaseTool):
valid_keys = signature(self.func).parameters
return {k: schema[k] for k in valid_keys}
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
"""Use the tool."""
return self.func(*args, **kwargs)
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get("callbacks")
return await self.coroutine(*args, **kwargs)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
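The signature check above makes callback support opt-in for wrapped functions: a `func` or `coroutine` that declares a `callbacks` parameter receives the child manager of the current tool run, so anything it invokes is traced under the same parent. A minimal sketch (the search body is a stand-in):

```python
from langchain.agents.tools import Tool
from langchain.callbacks.manager import Callbacks

def search(query: str, callbacks: Callbacks = None) -> str:
    # When invoked inside a tool run, `callbacks` is run_manager.get_child(),
    # ready to be forwarded to any nested chain or LLM call.
    return f"results for {query!r}"

tool = Tool(name="search", func=search, description="Look things up.")
```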

View File

@@ -0,0 +1,55 @@
"""Base class for all language models."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install transformers`."
)
# create a GPT-2 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-2 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the message."""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
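The new `BaseLanguageModel` base class pins down what both LLMs and chat models must implement (`generate_prompt`/`agenerate_prompt`, now taking a `callbacks` argument) and ships GPT-2-based token counting as a fallback. A usage sketch, assuming `OpenAI` subclasses this base and `transformers` is installed; concrete models may override the counters with model-specific tokenizers:

```python
from langchain.llms import OpenAI  # assumed BaseLanguageModel subclass
from langchain.schema import HumanMessage

llm = OpenAI()
print(llm.get_num_tokens("Once upon a time"))  # falls back to GPT-2 BPE
print(llm.get_num_tokens_from_messages([HumanMessage(content="hi")]))
```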

View File

@@ -1,80 +1,27 @@
"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator, Optional
from typing import Generator
from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.manager import (
CallbackManager,
get_openai_callback,
tracing_enabled,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.tracers import LangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler
def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()
def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)
def set_default_callback_manager() -> None:
"""Set default callback manager."""
default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
if default_handler == "stdout":
set_handler(StdOutCallbackHandler())
elif default_handler == "langchain":
session = os.environ.get("LANGCHAIN_SESSION")
set_tracing_callback_manager(session)
else:
raise ValueError(
f"LANGCHAIN_HANDLER should be one of `stdout` "
f"or `langchain`, got {default_handler}"
)
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
"""Set tracing callback manager."""
handler = SharedLangChainTracer()
callback = get_callback_manager()
callback.set_handlers([handler, StdOutCallbackHandler()])
if session_name is None:
handler.load_default_session()
else:
try:
handler.load_session(session_name)
except Exception:
raise ValueError(f"session {session_name} not found")
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get OpenAI callback handler in a context manager."""
handler = OpenAICallbackHandler()
manager = get_callback_manager()
manager.add_handler(handler)
yield handler
manager.remove_handler(handler)
__all__ = [
"CallbackManager",
"AsyncCallbackManager",
"OpenAICallbackHandler",
"SharedCallbackManager",
"StdOutCallbackHandler",
"AimCallbackHandler",
"WandbCallbackHandler",
@@ -82,8 +29,5 @@ __all__ = [
"CometCallbackHandler",
"AsyncIteratorCallbackHandler",
"get_openai_callback",
"set_tracing_callback_manager",
"set_default_callback_manager",
"set_handler",
"get_callback_manager",
"tracing_enabled",
]
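With the shared-manager plumbing removed, `get_openai_callback` becomes a ContextVar-backed context manager (defined in `langchain.callbacks.manager` and re-exported here), so it only observes runs started inside the `with` block. A sketch, assuming an OpenAI model:

```python
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI  # assumed for illustration

llm = OpenAI()
with get_openai_callback() as cb:
    llm("Tell me a joke")
    print(cb.total_tokens)  # accumulated by the handler during the call
```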

View File

@@ -1,19 +1,173 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
import asyncio
import functools
from abc import ABC, abstractmethod
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
import copy
from typing import Any, Dict, List, Optional, Union
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(ABC):
"""Base callback handler that can be used to handle callbacks from langchain."""
class LLMManagerMixin:
"""Mixin for LLM callbacks."""
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return False
def on_llm_new_token(
self,
token: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
def on_llm_end(
self,
response: LLMResult,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM errors."""
class ChainManagerMixin:
"""Mixin for chain callbacks."""
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
def on_agent_action(
self,
action: AgentAction,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
class ToolManagerMixin:
"""Mixin for tool callbacks."""
def on_tool_end(
self,
output: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
class CallbackManagerMixin:
"""Mixin for callback manager."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
class RunManagerMixin:
"""Mixin for run manager."""
def on_text(
self,
text: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run on arbitrary text."""
class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
"""Base callback handler that can be used to handle callbacks from langchain."""
@property
def ignore_llm(self) -> bool:
@@ -30,480 +184,197 @@ class BaseCallbackHandler(ABC):
"""Whether to ignore agent callbacks."""
return False
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
@abstractmethod
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
async def on_llm_new_token(
self,
token: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
@abstractmethod
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
async def on_llm_end(
self,
response: LLMResult,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
async def on_tool_end(
self,
output: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> Any:
async def on_text(
self,
text: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run on arbitrary text."""
@abstractmethod
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
async def on_agent_action(
self,
action: AgentAction,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run on agent action."""
@abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that can be used to handle callbacks from LangChain."""
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[str] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or []
)
self.parent_run_id: Optional[str] = parent_run_id
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
return False
@abstractmethod
def add_handler(self, callback: BaseCallbackHandler) -> None:
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
if inherit:
self.inheritable_handlers.append(handler)
@abstractmethod
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
self.inheritable_handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler])
self.set_handlers([handler], inherit=inherit)
@abstractmethod
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM generates a new token."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_new_token(token, **kwargs)
def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_end(response)
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_error(error)
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_end(outputs)
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_end(output, **kwargs)
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_error(error)
def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run on additional input from chains and agents."""
for handler in self.handlers:
if verbose or handler.always_verbose:
handler.on_text(text, **kwargs)
def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent end."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_finish(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
async def _handle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
if verbose or handler.always_verbose:
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)
class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain."""
@property
def is_async(self) -> bool:
"""Return whether the handler is async."""
return True
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
async def _handle_event(
self,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_handle_event_for_handler(
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
)
for handler in self.handlers
)
def __copy__(self) -> "BaseCallbackManager":
return self.__class__(
self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id
)
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
await self._handle_event(
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
def __deepcopy__(self, memo: dict) -> "BaseCallbackManager":
return self.__class__(
[copy.deepcopy(handler, memo) for handler in self.handlers],
[copy.deepcopy(handler, memo) for handler in self.inheritable_handlers],
self.parent_run_id,
)
async def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
await self._handle_event(
"on_llm_new_token", "ignore_llm", verbose, token, **kwargs
)
async def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
await self._handle_event(
"on_llm_end", "ignore_llm", verbose, response, **kwargs
)
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM errors."""
await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain starts running."""
await self._handle_event(
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
)
async def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
await self._handle_event(
"on_chain_end", "ignore_chain", verbose, outputs, **kwargs
)
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain errors."""
await self._handle_event(
"on_chain_error", "ignore_chain", verbose, error, **kwargs
)
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool starts running."""
await self._handle_event(
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
)
async def on_tool_end(
self, output: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool ends running."""
await self._handle_event(
"on_tool_end", "ignore_agent", verbose, output, **kwargs
)
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool errors."""
await self._handle_event(
"on_tool_error", "ignore_agent", verbose, error, **kwargs
)
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when text is printed."""
await self._handle_event("on_text", None, verbose, text, **kwargs)
async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent action."""
await self._handle_event(
"on_agent_action", "ignore_agent", verbose, action, **kwargs
)
async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when agent finishes."""
await self._handle_event(
"on_agent_finish", "ignore_agent", verbose, finish, **kwargs
)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers
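Because the event methods now live on no-op mixins instead of being `@abstractmethod`s, a custom handler can subclass `BaseCallbackHandler` and override only the hooks it cares about; `run_id` and `parent_run_id` arrive as keyword-only arguments. A minimal sketch:

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

class PrintRunIdsHandler(BaseCallbackHandler):
    """Logs only LLM start/end; every other event keeps its no-op default."""

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        print(f"LLM run {run_id} started (parent={parent_run_id})")

    def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None:
        print(f"LLM run {run_id} finished")
```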

View File

@@ -0,0 +1,770 @@
from __future__ import annotations
import asyncio
import copy
import functools
import os
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
ChainManagerMixin,
LLMManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
"tracing_callback", default=None
)
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get OpenAI callback handler in a context manager."""
cb = OpenAICallbackHandler()
openai_callback_var.set(cb)
yield cb
openai_callback_var.set(None)
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
"""Get OpenAI callback handler in a context manager."""
cb = LangChainTracer()
session = cb.load_session(session_name)
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
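Both context managers above work by stashing a handler in a ContextVar, which `_configure` (further down) picks up when building a manager, so they are scoped to the current context rather than process-global. For `tracing_enabled`, a sketch (`agent` assumed to be constructed elsewhere):

```python
from langchain.callbacks import tracing_enabled

with tracing_enabled(session_name="my-experiment") as session:
    # Runs started inside this block are traced to `session`.
    agent.run("What is 2 + 2?")  # `agent` assumed defined elsewhere
```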
def _handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
for handler in handlers:
try:
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
getattr(handler, event_name)(*args, **kwargs)
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
async def _ahandle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
async def _ahandle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
)
for handler in handlers
)
)
class BaseRunManager(RunManagerMixin):
"""Base class for run manager (a bound callback manager)."""
def __init__(
self,
run_id: str,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler],
parent_run_id: Optional[str] = None,
) -> None:
"""Initialize run manager."""
self.run_id = run_id
self.handlers = handlers
self.inheritable_handlers = inheritable_handlers
self.parent_run_id = parent_run_id
class RunManager(BaseRunManager):
"""Sync Run Manager."""
def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received."""
_handle_event(
self.handlers,
"on_text",
None,
text,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncRunManager(BaseRunManager):
"""Async Run Manager."""
async def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received."""
await _ahandle_event(
self.handlers,
"on_text",
None,
text,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
"""Callback manager for LLM run."""
def on_llm_new_token(
self,
token: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token."""
_handle_event(
self.handlers,
"on_llm_new_token",
"ignore_llm",
token=token,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
_handle_event(
self.handlers,
"on_llm_end",
"ignore_llm",
response,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
_handle_event(
self.handlers,
"on_llm_error",
"ignore_llm",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Async callback manager for LLM run."""
async def on_llm_new_token(
self,
token: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token."""
await _ahandle_event(
self.handlers,
"on_llm_new_token",
"ignore_llm",
token,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
await _ahandle_event(
self.handlers,
"on_llm_end",
"ignore_llm",
response,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
await _ahandle_event(
self.handlers,
"on_llm_error",
"ignore_llm",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def get_child(self) -> Callbacks:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
_handle_event(
self.handlers,
"on_chain_end",
"ignore_chain",
outputs,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when chain errors."""
_handle_event(
self.handlers,
"on_chain_error",
"ignore_chain",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received."""
_handle_event(
self.handlers,
"on_agent_action",
"ignore_agent",
action,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received."""
_handle_event(
self.handlers,
"on_agent_finish",
"ignore_agent",
finish,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
def get_child(self) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
await _ahandle_event(
self.handlers,
"on_chain_end",
"ignore_chain",
outputs,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when chain errors."""
await _ahandle_event(
self.handlers,
"on_chain_error",
"ignore_chain",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received."""
await _ahandle_event(
self.handlers,
"on_agent_action",
"ignore_agent",
action,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received."""
await _ahandle_event(
self.handlers,
"on_agent_finish",
"ignore_agent",
finish,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
"""Callback manager for tool run."""
def get_child(self) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
def on_tool_end(
self,
output: str,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
_handle_event(
self.handlers,
"on_tool_end",
"ignore_agent",
output,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when tool errors."""
_handle_event(
self.handlers,
"on_tool_error",
"ignore_agent",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
"""Async callback manager for tool run."""
def get_child(self) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
await _ahandle_event(
self.handlers,
"on_tool_end",
"ignore_agent",
output,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when tool errors."""
await _ahandle_event(
self.handlers,
"on_tool_error",
"ignore_agent",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[str] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
"""Run when LLM starts running."""
if run_id is None:
run_id = str(uuid.uuid4())
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
run_id: Optional[str] = None,
**kwargs: Any,
) -> CallbackManagerForChainRun:
"""Run when chain starts running."""
if run_id is None:
run_id = str(uuid.uuid4())
_handle_event(
self.handlers,
"on_chain_start",
"ignore_chain",
serialized,
inputs,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
run_id: Optional[str] = None,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> CallbackManagerForToolRun:
"""Run when tool starts running."""
if run_id is None:
run_id = str(uuid.uuid4())
_handle_event(
self.handlers,
"on_tool_start",
"ignore_agent",
serialized,
input_str,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
@classmethod
def configure(
cls,
inheritable_callbacks: Optional[
Union[BaseCallbackManager, List[BaseCallbackHandler]]
] = None,
local_callbacks: Optional[
Union[BaseCallbackManager, List[BaseCallbackHandler]]
] = None,
verbose: bool = False,
) -> Optional[BaseCallbackManager]:
"""Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain."""
@property
def is_async(self) -> bool:
"""Return whether the handler is async."""
return True
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[str] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForLLMRun:
"""Run when LLM starts running."""
if run_id is None:
run_id = str(uuid.uuid4())
await _ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
run_id: Optional[str] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForChainRun:
"""Run when chain starts running."""
if run_id is None:
run_id = str(uuid.uuid4())
await _ahandle_event(
self.handlers,
"on_chain_start",
"ignore_chain",
serialized,
inputs,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
run_id: Optional[str] = None,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForToolRun:
"""Run when tool starts running."""
if run_id is None:
run_id = str(uuid.uuid4())
await _ahandle_event(
self.handlers,
"on_tool_start",
"ignore_agent",
serialized,
input_str,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
@classmethod
def configure(
cls,
inheritable_callbacks: Optional[
Union[BaseCallbackManager, List[BaseCallbackHandler]]
] = None,
local_callbacks: Optional[
Union[BaseCallbackManager, List[BaseCallbackHandler]]
] = None,
verbose: bool = False,
) -> Optional[BaseCallbackManager]:
"""Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
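`configure` is the intended entry point for components: inheritable callbacks flow down to child runs, local callbacks apply only to the current component, and `verbose=True` adds a `StdOutCallbackHandler` when one is not already present. A sketch of direct use:

```python
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

cm = CallbackManager.configure(
    inheritable_callbacks=[StdOutCallbackHandler()],  # propagated to children
    local_callbacks=None,
    verbose=True,  # would add a StdOutCallbackHandler if missing
)
```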
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
def _configure(
callback_manager_cls: Type[T],
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
) -> T:
"""Configure the callback manager."""
callback_manager = callback_manager_cls([])
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
inheritable_callbacks_: List[BaseCallbackHandler] = (
inheritable_callbacks or []
)
callback_manager = callback_manager_cls(
handlers=inheritable_callbacks_,
inheritable_handlers=inheritable_callbacks_,
)
else:
callback_manager = callback_manager_cls(
handlers=inheritable_callbacks.handlers,
inheritable_handlers=inheritable_callbacks.inheritable_handlers,
parent_run_id=inheritable_callbacks.parent_run_id,
)
callback_manager = copy.deepcopy(callback_manager)
local_handlers_ = (
local_callbacks
if isinstance(local_callbacks, list)
else (local_callbacks.handlers if local_callbacks else [])
)
for handler in local_handlers_:
callback_manager.add_handler(copy.deepcopy(handler), False)
tracer = tracing_callback_var.get()
open_ai = openai_callback_var.get()
tracing_enabled_ = (
os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None
)
if verbose or tracing_enabled_ or open_ai is not None:
if verbose and not any(
isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(StdOutCallbackHandler(), False)
if tracing_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer:
callback_manager.add_handler(copy.deepcopy(tracer), True)
else:
handler = LangChainTracer()
handler.load_default_session()
callback_manager.add_handler(handler, True)
if open_ai is not None and not any(
isinstance(handler, OpenAICallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(open_ai, True)
return callback_manager
class NullCallbackManagerForChainRun(CallbackManagerForChainRun):
def __init__(self) -> None:
super().__init__("", [], [])
def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
pass
def get_child(self) -> Callbacks:
return None
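Putting the pieces together, the run lifecycle these managers implement looks like this: each `on_*_start` mints a `run_id`, notifies handlers, and returns a bound run manager; `get_child()` produces a manager carrying the inheritable handlers and the parent `run_id` for nested work; the matching `on_*_end` closes the run. A sketch using only classes from this file:

```python
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

cm = CallbackManager([StdOutCallbackHandler()])
run = cm.on_chain_start({"name": "demo"}, {"input": "hi"})  # returns a run manager
child = run.get_child()  # callbacks for nested chain/LLM/tool runs
run.on_chain_end({"output": "done"})
```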

View File

@@ -1,127 +0,0 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List, Union
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
class Singleton:
"""A thread-safe singleton class that can be inherited from."""
_instance = None
_lock = threading.Lock()
def __new__(cls) -> Any:
"""Create a new shared instance of the class."""
if cls._instance is None:
with cls._lock:
# Another thread could have created the instance
# before we acquired the lock. So check that the
# instance is still nonexistent.
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
class SharedCallbackManager(Singleton, BaseCallbackManager):
"""A thread-safe singleton CallbackManager."""
_callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
with self._lock:
self._callback_manager.on_llm_end(response, **kwargs)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
with self._lock:
self._callback_manager.on_llm_new_token(token, **kwargs)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
with self._lock:
self._callback_manager.on_llm_error(error, **kwargs)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
with self._lock:
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
with self._lock:
self._callback_manager.on_chain_end(outputs, **kwargs)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
with self._lock:
self._callback_manager.on_chain_error(error, **kwargs)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
with self._lock:
self._callback_manager.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
with self._lock:
self._callback_manager.on_tool_end(output, **kwargs)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
with self._lock:
self._callback_manager.on_tool_error(error, **kwargs)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
with self._lock:
self._callback_manager.on_text(text, **kwargs)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
with self._lock:
self._callback_manager.on_agent_finish(finish, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:
self._callback_manager.add_handler(callback)
def remove_handler(self, callback: BaseCallbackHandler) -> None:
"""Remove a callback from the callback manager."""
with self._lock:
self._callback_manager.remove_handler(callback)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
with self._lock:
self._callback_manager.handlers = handlers
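Deleting `SharedCallbackManager` (and its `Singleton` base) removes the last piece of process-global, lock-guarded callback state; the replacement is per-call `callbacks` arguments plus the ContextVar-scoped context managers in `langchain.callbacks.manager`. A hedged migration sketch, assuming the component's call accepts the new `callbacks` argument:

```python
from langchain.callbacks.stdout import StdOutCallbackHandler

# Old (removed): SharedCallbackManager().add_handler(StdOutCallbackHandler())
# New: pass handlers at the call site; they propagate to child runs.
llm("Hello", callbacks=[StdOutCallbackHandler()])  # `llm` assumed defined
```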

View File

@@ -1,12 +1,5 @@
"""Tracers that record execution of LangChain runs."""
from langchain.callbacks.tracers.base import SharedTracer, Tracer
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer
class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
"""Shared tracer that records LangChain execution to LangChain endpoint."""
class LangChainTracer(Tracer, BaseLangChainTracer):
"""Tracer that records LangChain execution to LangChain endpoint."""
__all__ = ["LangChainTracer"]

View File

@@ -1,14 +1,12 @@
"""Base interfaces for tracing runs."""
from __future__ import annotations
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.shared import Singleton
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
@@ -16,7 +14,7 @@ from langchain.callbacks.tracers.schemas import (
TracerSession,
TracerSessionCreate,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult
class TracerException(Exception):
@@ -26,13 +24,25 @@ class TracerException(Exception):
class BaseTracer(BaseCallbackHandler, ABC):
"""Base interface for tracers."""
@abstractmethod
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[TracerSession] = None
@staticmethod
def _add_child_run(
self,
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
if isinstance(child_run, LLMRun):
parent_run.child_llm_runs.append(child_run)
elif isinstance(child_run, ChainRun):
parent_run.child_chain_runs.append(child_run)
elif isinstance(child_run, ToolRun):
parent_run.child_tool_runs.append(child_run)
else:
raise TracerException(f"Invalid run type: {type(child_run)}")
@abstractmethod
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
@@ -42,15 +52,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
"""Persist a tracing session."""
@abstractmethod
def _generate_id(self) -> Optional[Union[int, str]]:
"""Generate an id for a run."""
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
"""NOT thread safe, do not call this method from multiple threads."""
session_create = TracerSessionCreate(name=name, extra=kwargs)
session = self._persist_session(session_create)
self._session = session
self.session = session
return session
@abstractmethod
@@ -61,283 +67,232 @@ class BaseTracer(BaseCallbackHandler, ABC):
def load_default_session(self) -> TracerSession:
"""Load the default tracing session and set it as the Tracer's session."""
@property
@abstractmethod
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
"""Get the tracer stack."""
@property
@abstractmethod
def _execution_order(self) -> int:
"""Get the execution order for a run."""
@_execution_order.setter
@abstractmethod
def _execution_order(self, value: int) -> None:
"""Set the execution order for a run."""
@property
@abstractmethod
def _session(self) -> Optional[TracerSession]:
"""Get the tracing session."""
@_session.setter
@abstractmethod
def _session(self, value: TracerSession) -> None:
"""Set the tracing session."""
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Start a trace for a run."""
self._execution_order += 1
if self._stack:
if not (
isinstance(self._stack[-1], ChainRun)
or isinstance(self._stack[-1], ToolRun)
):
if run.parent_uuid:
parent_run = self.run_map[run.parent_uuid]
if parent_run:
if isinstance(parent_run, LLMRun):
raise TracerException(
"Cannot add child run to an LLM run. "
"LLM runs are not allowed to have children."
)
self._add_child_run(parent_run, run)
else:
raise TracerException(
f"Nested {run.__class__.__name__} can only be"
f" logged inside a ChainRun or ToolRun"
f"Parent run with UUID {run.parent_uuid} not found."
)
self._add_child_run(self._stack[-1], run)
self._stack.append(run)
def _end_trace(self) -> None:
self.run_map[run.uuid] = run
def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""End a trace for a run."""
run = self._stack.pop()
if not self._stack:
self._execution_order = 1
if not run.parent_uuid:
self._persist_run(run)
else:
parent_run = self.run_map.get(run.parent_uuid)
if parent_run is None:
raise TracerException(
f"Parent run with UUID {run.parent_uuid} not found."
)
if isinstance(parent_run, LLMRun):
raise TracerException("LLM Runs are not allowed to have children. ")
if run.child_execution_order > parent_run.child_execution_order:
parent_run.child_execution_order = run.child_execution_order
self.run_map.pop(run.uuid)
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
"""Get the execution order for a run."""
if parent_run_id is None:
return 1
parent_run = self.run_map.get(parent_run_id)
if parent_run is None:
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
if isinstance(parent_run, LLMRun):
raise TracerException("LLM Runs are not allowed to have children. ")
return parent_run.child_execution_order + 1
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
if self._session is None:
raise TracerException(
"Initialize a session with `new_session()` before starting a trace."
)
if self.session is None:
self.session = self.load_default_session()
if run_id is None:
run_id = str(uuid4())
execution_order = self._get_execution_order(parent_run_id)
llm_run = LLMRun(
uuid=run_id,
parent_uuid=parent_run_id,
serialized=serialized,
prompts=prompts,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
session_id=self._session.id,
id=self._generate_id(),
execution_order=execution_order,
child_execution_order=execution_order,
session_id=self.session.id,
)
self._start_trace(llm_run)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Handle a new token for an LLM run."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None:
"""End a trace for an LLM run."""
if not self._stack or not isinstance(self._stack[-1], LLMRun):
if not run_id:
raise TracerException("No run_id provided for on_llm_end callback.")
llm_run = self.run_map.get(run_id)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].response = response
self._end_trace()
llm_run.response = response
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
**kwargs: Any,
) -> None:
"""Handle an error for an LLM run."""
if not self._stack or not isinstance(self._stack[-1], LLMRun):
if not run_id:
raise TracerException("No run_id provided for on_llm_error callback.")
llm_run = self.run_map.get(run_id)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
self._stack[-1].error = repr(error)
self._stack[-1].end_time = datetime.utcnow()
self._end_trace()
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a chain run."""
if self._session is None:
raise TracerException(
"Initialize a session with `new_session()` before starting a trace."
)
if self.session is None:
self.session = self.load_default_session()
execution_order = self._get_execution_order(parent_run_id)
chain_run = ChainRun(
uuid=run_id,
parent_uuid=parent_run_id,
serialized=serialized,
inputs=inputs,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
session_id=self._session.id,
id=self._generate_id(),
session_id=self.session.id,
)
self._start_trace(chain_run)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(
self, outputs: Dict[str, Any], *, run_id: str, **kwargs: Any
) -> None:
"""End a trace for a chain run."""
if not self._stack or not isinstance(self._stack[-1], ChainRun):
chain_run = self.run_map.get(run_id)
if chain_run is None or not isinstance(chain_run, ChainRun):
raise TracerException("No ChainRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].outputs = outputs
self._end_trace()
chain_run.outputs = outputs
chain_run.end_time = datetime.utcnow()
self._end_trace(chain_run)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
**kwargs: Any,
) -> None:
"""Handle an error for a chain run."""
if not self._stack or not isinstance(self._stack[-1], ChainRun):
chain_run = self.run_map.get(run_id)
if chain_run is None or not isinstance(chain_run, ChainRun):
raise TracerException("No ChainRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].error = repr(error)
self._end_trace()
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
self._end_trace(chain_run)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: str,
parent_run_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a tool run."""
if self._session is None:
raise TracerException(
"Initialize a session with `new_session()` before starting a trace."
)
if self.session is None:
self.session = self.load_default_session()
execution_order = self._get_execution_order(parent_run_id)
tool_run = ToolRun(
uuid=run_id,
parent_uuid=parent_run_id,
serialized=serialized,
# TODO: this duplicates the serialized info above and is not needed.
action=str(serialized),
tool_input=input_str,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
session_id=self._session.id,
id=self._generate_id(),
session_id=self.session.id,
)
self._start_trace(tool_run)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: str, *, run_id: str, **kwargs: Any) -> None:
"""End a trace for a tool run."""
if not self._stack or not isinstance(self._stack[-1], ToolRun):
tool_run = self.run_map.get(run_id)
if tool_run is None or not isinstance(tool_run, ToolRun):
raise TracerException("No ToolRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].output = output
self._end_trace()
tool_run.output = output
tool_run.end_time = datetime.utcnow()
self._end_trace(tool_run)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: str,
**kwargs: Any,
) -> None:
"""Handle an error for a tool run."""
if not self._stack or not isinstance(self._stack[-1], ToolRun):
tool_run = self.run_map.get(run_id)
if tool_run is None or not isinstance(tool_run, ToolRun):
raise TracerException("No ToolRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].error = repr(error)
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
self._end_trace(tool_run)
self._end_trace()
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
return self
def on_text(self, text: str, **kwargs: Any) -> None:
"""Handle a text message."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Handle an agent finish message."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
class Tracer(BaseTracer, ABC):
"""A non-thread safe implementation of the BaseTracer interface."""
def __init__(self) -> None:
"""Initialize a tracer."""
self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
self._tracer_execution_order = 1
self._tracer_session: Optional[TracerSession] = None
@property
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
"""Get the tracer stack."""
return self._tracer_stack
@property
def _execution_order(self) -> int:
"""Get the execution order for a run."""
return self._tracer_execution_order
@_execution_order.setter
def _execution_order(self, value: int) -> None:
"""Set the execution order for a run."""
self._tracer_execution_order = value
@property
def _session(self) -> Optional[TracerSession]:
"""Get the tracing session."""
return self._tracer_session
@_session.setter
def _session(self, value: TracerSession) -> None:
"""Set the tracing session."""
if self._stack:
raise TracerException(
"Cannot set a session while a trace is being recorded"
)
self._tracer_session = value
@dataclass
class TracerStack(threading.local):
"""A stack of runs used for logging."""
stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
execution_order: int = 1
class SharedTracer(Singleton, BaseTracer, ABC):
"""A thread-safe Singleton implementation of BaseTracer."""
_tracer_stack = TracerStack()
_tracer_session = None
@property
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
"""Get the tracer stack."""
return self._tracer_stack.stack
@property
def _execution_order(self) -> int:
"""Get the execution order for a run."""
return self._tracer_stack.execution_order
@_execution_order.setter
def _execution_order(self, value: int) -> None:
"""Set the execution order for a run."""
self._tracer_stack.execution_order = value
@property
def _session(self) -> Optional[TracerSession]:
"""Get the tracing session."""
return self._tracer_session
@_session.setter
def _session(self, value: TracerSession) -> None:
"""Set the tracing session."""
with self._lock:
# TODO: currently, we are only checking current thread's stack.
# Need to make sure that we are not in the middle of a trace
# in any thread.
if self._stack:
raise TracerException(
"Cannot set a session while a trace is being recorded"
)
self._tracer_session = value
def __copy__(self) -> BaseTracer:
"""Copy the tracer."""
return self
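The tracer refactor above replaces the thread-local `_stack` / `_execution_order` bookkeeping (and the `Tracer` / `SharedTracer` split) with a flat `run_map` keyed by `run_id`, with parent/child links carried explicitly through `parent_run_id`. A minimal, framework-independent sketch of that bookkeeping, using a hypothetical `Run` stand-in for `LLMRun`/`ChainRun`/`ToolRun`:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Run:
    # Hypothetical stand-in for LLMRun / ChainRun / ToolRun.
    run_id: str
    parent_run_id: Optional[str] = None
    execution_order: int = 1
    child_execution_order: int = 1
    child_runs: List["Run"] = field(default_factory=list)


class RunMapTracer:
    """Sketch of the run_map bookkeeping: no stack, no thread-local state."""

    def __init__(self) -> None:
        self.run_map: Dict[str, Run] = {}

    def _get_execution_order(self, parent_run_id: Optional[str]) -> int:
        # A new child runs one step after the deepest run seen under its parent.
        if parent_run_id is None:
            return 1
        return self.run_map[parent_run_id].child_execution_order + 1

    def _start_trace(self, run: Run) -> None:
        if run.parent_run_id is not None:
            self.run_map[run.parent_run_id].child_runs.append(run)
        self.run_map[run.run_id] = run

    def _end_trace(self, run: Run) -> None:
        if run.parent_run_id is not None:
            # Bubble the deepest execution order up to the parent.
            parent = self.run_map[run.parent_run_id]
            parent.child_execution_order = max(
                parent.child_execution_order, run.child_execution_order
            )
        # A real tracer would persist top-level runs here (_persist_run).
        self.run_map.pop(run.run_id, None)
```

Because each callback now carries its own `run_id`, concurrent runs no longer contend for a shared stack, which is what makes the `Singleton` machinery and the `__copy__`/`__deepcopy__`-returning-`self` tricks unnecessary.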

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging
import os
from abc import ABC
from typing import Any, Dict, Optional, Union
import requests
@@ -18,14 +17,17 @@ from langchain.callbacks.tracers.schemas import (
)
class BaseLangChainTracer(BaseTracer, ABC):
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
always_verbose: bool = True
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
self.session = self.load_session(session_name)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
@@ -59,54 +61,29 @@ class BaseLangChainTracer(BaseTracer, ABC):
session = TracerSession(id=1, **session_create.dict())
return session
def load_session(self, session_name: str) -> TracerSession:
def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
"""Load a session from the tracer."""
try:
r = requests.get(
f"{self._endpoint}/sessions?name={session_name}",
headers=self._headers,
)
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)
tracer_session = TracerSession(**r.json()[0])
self._session = tracer_session
return tracer_session
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load session {session_name}, using empty session: {e}"
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSession(id=1)
self._session = tracer_session
return tracer_session
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> TracerSession:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> TracerSession:
"""Load the default tracing session and set it as the Tracer's session."""
try:
r = requests.get(
f"{self._endpoint}/sessions",
headers=self._headers,
)
# Use the first session result
tracer_session = TracerSession(**r.json()[0])
self._session = tracer_session
return tracer_session
except Exception as e:
logging.warning(f"Failed to default session, using empty session: {e}")
tracer_session = TracerSession(id=1)
self._session = tracer_session
return tracer_session
def _add_child_run(
self,
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
if isinstance(child_run, LLMRun):
parent_run.child_llm_runs.append(child_run)
elif isinstance(child_run, ChainRun):
parent_run.child_chain_runs.append(child_run)
else:
parent_run.child_tool_runs.append(child_run)
def _generate_id(self) -> Optional[Union[int, str]]:
"""Generate an id for a run."""
return None
return self._load_session("default")
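With the rename from `BaseLangChainTracer` to `LangChainTracer`, the endpoint and headers move from class attributes into `__init__`, and a session is resolved eagerly by name. A usage sketch under those assumptions (the import path is inferred from the file layout above):

```python
import os

# Import path assumed; the class lives alongside the tracer schemas.
from langchain.callbacks.tracers.langchain import LangChainTracer

# The tracer reads its endpoint and optional API key from the environment.
os.environ["LANGCHAIN_ENDPOINT"] = "http://localhost:8000"
# os.environ["LANGCHAIN_API_KEY"] = "..."  # sent as the x-api-key header

tracer = LangChainTracer(session_name="my-experiment")
# If the endpoint is unreachable, _load_session logs a warning and falls
# back to an empty TracerSession(id=1), so construction still succeeds.
```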

View File

@@ -32,11 +32,13 @@ class TracerSession(TracerSessionBase):
class BaseRun(BaseModel):
"""Base class for Run."""
id: Optional[Union[int, str]] = None
uuid: str
parent_uuid: Optional[str] = None
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: Optional[Dict[str, Any]] = None
execution_order: int
child_execution_order: int
serialized: Dict[str, Any]
session_id: int
error: Optional[str] = None
@@ -57,7 +59,6 @@ class ChainRun(BaseRun):
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
class ToolRun(BaseRun):
@@ -69,7 +70,6 @@ class ToolRun(BaseRun):
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
ChainRun.update_forward_refs()
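The schema change gives every run a client-generated `uuid` plus an optional `parent_uuid` and a `child_execution_order`, and adds a single ordered `child_runs` list alongside the per-type child lists on `ChainRun` and `ToolRun`. A hedged sketch of building a nested trace by hand with these fields (`prompts` on `LLMRun` and `inputs` on `ChainRun` are assumed from the rest of the schema module):

```python
from uuid import uuid4

from langchain.callbacks.tracers.schemas import ChainRun, LLMRun

parent_id, child_id = str(uuid4()), str(uuid4())

llm_run = LLMRun(
    uuid=child_id,
    parent_uuid=parent_id,
    serialized={"name": "FakeLLM"},
    prompts=["hello"],
    execution_order=2,
    child_execution_order=2,
    session_id=1,
)

chain_run = ChainRun(
    uuid=parent_id,
    serialized={"name": "FakeChain"},
    inputs={"text": "hello"},
    execution_order=1,
    child_execution_order=2,  # deepest order reached by any descendant
    session_id=1,
    child_runs=[llm_run],
)
```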

View File

@@ -5,12 +5,16 @@ from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
NullCallbackManagerForChainRun,
)
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel
class APIChain(Chain):
@@ -61,16 +65,18 @@ class APIChain(Chain):
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: CallbackManagerForChainRun = NullCallbackManagerForChainRun(),
) -> Dict[str, str]:
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
question=question, api_docs=self.api_docs
)
self.callback_manager.on_text(
api_url, color="green", end="\n", verbose=self.verbose
question=question, api_docs=self.api_docs, callbacks=run_manager.get_child()
)
run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
api_response = self.requests_wrapper.get(api_url)
self.callback_manager.on_text(
run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose
)
answer = self.api_answer_chain.predict(

View File

@@ -1,15 +1,24 @@
"""Base interface that all chains should implement."""
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
NullCallbackManagerForChainRun,
)
from langchain.schema import BaseMemory
@@ -21,9 +30,8 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callback_manager: BaseCallbackManager = Field(
default_factory=get_callback_manager, exclude=True
)
callbacks: Callbacks = None
callback_manager: Optional[BaseCallbackManager] = None
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
@@ -37,15 +45,16 @@ class Chain(BaseModel, ABC):
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
@@ -82,15 +91,26 @@ class Chain(BaseModel, ABC):
)
@abstractmethod
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: CallbackManagerForChainRun = NullCallbackManagerForChainRun(),
) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
raise NotImplementedError("Async call not supported for this chain type.")
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -104,21 +124,31 @@ class Chain(BaseModel, ABC):
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = self._call(inputs)
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
run_manager.on_chain_error(e)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
async def acall(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -132,30 +162,24 @@ class Chain(BaseModel, ABC):
"""
inputs = self.prep_inputs(inputs)
if self.callback_manager.is_async:
await self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
else:
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
)
try:
outputs = await self._acall(inputs)
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_chain_error(e, verbose=self.verbose)
else:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
await run_manager.on_chain_error(e)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
else:
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
await run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
def prep_outputs(
@@ -195,11 +219,13 @@ class Chain(BaseModel, ABC):
self._validate_inputs(inputs)
return inputs
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(self, *args: Any, **kwargs: Any) -> str:
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@@ -210,17 +236,17 @@ class Chain(BaseModel, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0])[self.output_keys[0]]
return self(args[0], callbacks=callbacks)[self.output_keys[0]]
if kwargs and not args:
return self(kwargs)[self.output_keys[0]]
return self(kwargs, callbacks=callbacks)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
async def arun(self, *args: Any, **kwargs: Any) -> str:
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@@ -231,10 +257,10 @@ class Chain(BaseModel, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (await self.acall(args[0]))[self.output_keys[0]]
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]
if kwargs and not args:
return (await self.acall(kwargs))[self.output_keys[0]]
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"

View File

@@ -1,13 +1,13 @@
"""Chain for applying constitutional principles to the outputs of another chain."""
from typing import Any, Dict, List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.principles import PRINCIPLES
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class ConstitutionalChain(Chain):

View File

@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@@ -15,7 +16,7 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
from langchain.schema import BaseMessage, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
# Depending on the memory type and configuration, the chat history format may differ.

View File

@@ -5,11 +5,17 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
from langchain.schema import LLMResult, PromptValue
class LLMChain(Chain):
@@ -53,21 +59,39 @@ class LLMChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return self.apply([inputs], run_manager=run_manager)[0]
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
return self.llm.generate_prompt(prompts, stop)
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
)
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
async def agenerate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list)
return await self.llm.agenerate_prompt(prompts, stop)
return await self.llm.agenerate_prompt(
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
)
def prep_prompts(
self, input_list: List[Dict[str, Any]]
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
@@ -79,7 +103,8 @@ class LLMChain(Chain):
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
@@ -88,7 +113,9 @@ class LLMChain(Chain):
return prompts, stop
async def aprep_prompts(
self, input_list: List[Dict[str, Any]]
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
@@ -100,12 +127,8 @@ class LLMChain(Chain):
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if self.callback_manager.is_async:
await self.callback_manager.on_text(
_text, end="\n", verbose=self.verbose
)
else:
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
if run_manager:
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
@@ -113,14 +136,22 @@ class LLMChain(Chain):
prompts.append(prompt)
return prompts, stop
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
def apply(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = self.generate(input_list)
response = self.generate(input_list, run_manager=run_manager)
return self.create_outputs(response)
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
async def aapply(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = await self.agenerate(input_list)
response = await self.agenerate(input_list, run_manager=run_manager)
return self.create_outputs(response)
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
@@ -131,13 +162,18 @@ class LLMChain(Chain):
for generation in response.generations
]
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return (await self.aapply([inputs]))[0]
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return (await self.aapply([inputs], run_manager=run_manager))[0]
def predict(self, **kwargs: Any) -> str:
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
@@ -148,12 +184,13 @@ class LLMChain(Chain):
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
return self(kwargs, callbacks=callbacks)[self.output_key]
async def apredict(self, **kwargs: Any) -> str:
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
@@ -164,31 +201,35 @@ class LLMChain(Chain):
completion = llm.predict(adjective="funny")
"""
return (await self.acall(kwargs))[self.output_key]
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results."""
result = self.predict(**kwargs)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
self, **kwargs: Any
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call apredict and then parse the results."""
result = await self.apredict(**kwargs)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]]
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list)
result = self.apply(input_list, run_manager=run_manager)
return self._parse_result(result)
def _parse_result(
@@ -202,10 +243,12 @@ class LLMChain(Chain):
return result
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]]
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = await self.aapply(input_list)
result = await self.aapply(input_list, run_manager=run_manager)
return self._parse_result(result)
@property
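Inside `LLMChain`, the `run_manager` handed to `_call` is also how callbacks propagate downward: `run_manager.get_child()` produces the callbacks passed to `llm.generate_prompt`, so the nested LLM run is recorded as a child of the chain run. A sketch of that pattern in a hypothetical custom chain wrapping an inner `LLMChain`:

```python
from typing import Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain


class WrapperChain(Chain):
    """Hypothetical chain forwarding callbacks to an inner LLMChain."""

    inner: LLMChain

    @property
    def input_keys(self) -> List[str]:
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        return ["answer"]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        # get_child() hands the inner chain callbacks whose runs are
        # recorded as children of this chain's run.
        child_callbacks = run_manager.get_child() if run_manager else None
        answer = self.inner.predict(
            question=inputs["question"], callbacks=child_callbacks
        )
        return {"answer": answer}
```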

View File

@@ -3,11 +3,11 @@ from typing import Dict, List
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities.bash import BashProcess

View File

@@ -1,16 +1,20 @@
"""Chain that interprets a prompt and executes python code to do math."""
import math
import re
from typing import Dict, List
from typing import Dict, List, Optional
import numexpr
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LLMMathChain(Chain):
@@ -68,15 +72,19 @@ class LLMMathChain(Chain):
# Remove any leading and trailing brackets from the output
return re.sub(r"^\[|\]$", "", output)
def _process_llm_result(self, llm_output: str) -> Dict[str, str]:
self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose)
def _process_llm_result(
self, llm_output: str, run_manager: Optional[CallbackManagerForChainRun] = None
) -> Dict[str, str]:
if run_manager:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
if run_manager:
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
@@ -86,30 +94,21 @@ class LLMMathChain(Chain):
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]:
if self.callback_manager.is_async:
await self.callback_manager.on_text(
llm_output, color="green", verbose=self.verbose
)
else:
self.callback_manager.on_text(
llm_output, color="green", verbose=self.verbose
)
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
if run_manager:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
if self.callback_manager.is_async:
await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
await self.callback_manager.on_text(
output, color="yellow", verbose=self.verbose
)
else:
await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
await self.callback_manager.on_text(
output, color="yellow", verbose=self.verbose
)
if run_manager:
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
@@ -119,30 +118,35 @@ class LLMMathChain(Chain):
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(
prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
)
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
if run_manager:
run_manager.on_text(inputs[self.input_key])
llm_output = llm_executor.predict(
question=inputs[self.input_key], stop=["```output"]
question=inputs[self.input_key],
stop=["```output"],
callbacks=run_manager.get_child() if run_manager else None,
)
return self._process_llm_result(llm_output)
return self._process_llm_result(llm_output, run_manager=run_manager)
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(
prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
)
if self.callback_manager.is_async:
await self.callback_manager.on_text(
inputs[self.input_key], verbose=self.verbose
)
else:
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
if run_manager:
await run_manager.on_text(inputs[self.input_key])
llm_output = await llm_executor.apredict(
question=inputs[self.input_key], stop=["```output"]
question=inputs[self.input_key],
stop=["```output"],
callbacks=run_manager.get_child() if run_manager else None,
)
return await self._aprocess_llm_result(llm_output)
return await self._aprocess_llm_result(llm_output, run_manager=run_manager)
@property
def _chain_type(self) -> str:
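Taken together, `LLMMathChain` now builds its inner `LLMChain` without inheriting a `callback_manager` and guards every `on_text` call behind `if run_manager:`. A usage sketch (assumes an `OPENAI_API_KEY` in the environment):

```python
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.chains import LLMMathChain
from langchain.llms import OpenAI

llm_math = LLMMathChain(llm=OpenAI(temperature=0), verbose=True)
# The colored on_text output now flows through the run manager created for
# this call, rather than a chain-level callback_manager.
answer = llm_math.run(
    "What is 13 raised to the 0.3432 power?",
    callbacks=[StdOutCallbackHandler()],
)
print(answer)
```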

View File

@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities import PythonREPL

View File

@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):

View File

@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -21,7 +22,6 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
)
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BaseQAWithSourcesChain(Chain, ABC):

View File

@@ -1,6 +1,7 @@
"""Load question answering with sources chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -14,7 +15,6 @@ from langchain.chains.qa_with_sources import (
)
from langchain.chains.question_answering import map_rerank_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@@ -1,6 +1,7 @@
"""Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
stuff_prompt,
)
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@@ -14,7 +15,7 @@ from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore

View File

@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.sql_database import SQLDatabase

View File

@@ -1,6 +1,7 @@
"""Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -8,7 +9,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@@ -2,6 +2,10 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
@@ -85,7 +89,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
) # trim off the trailing ' ' that might come from the "Assistant: "
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
@@ -98,10 +105,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
for data in stream_resp:
delta = data["completion"][len(completion) :]
completion = data["completion"]
self.callback_manager.on_llm_new_token(
delta,
verbose=self.verbose,
)
if run_manager:
run_manager.on_llm_new_token(
delta,
)
else:
response = self.client.completion(**params)
completion = response["completion"]
@@ -109,7 +116,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
@@ -122,15 +132,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
async for data in stream_resp:
delta = data["completion"][len(completion) :]
completion = data["completion"]
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
if run_manager:
await run_manager.on_llm_new_token(
delta,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_new_token(
delta,
verbose=self.verbose,
)
else:
response = await self.client.acompletion(**params)

View File

@@ -1,21 +1,30 @@
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Dict, List, Optional
from pydantic import Extra, Field, validator
from pydantic import Extra, Field, root_validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import (
AIMessage,
BaseLanguageModel,
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
PromptValue,
get_buffer_string,
)
@@ -26,7 +35,19 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, ABC):
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
callbacks: Callbacks = None
callback_manager: Optional[BaseCallbackManager] = None
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
"""Configuration for this pydantic object."""
@@ -34,98 +55,130 @@ class BaseChatModel(BaseLanguageModel, ABC):
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Top Level call"""
results = [self._generate(m, stop=stop) for m in messages]
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
try:
results = [
self._generate(m, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._generate(m, stop=stop)
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
return LLMResult(generations=generations, llm_output=llm_output)
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
return output
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Top Level call"""
results = await asyncio.gather(
*[self._agenerate(m, stop=stop) for m in messages]
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
try:
results = await asyncio.gather(
*[
self._agenerate(m, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._agenerate(m, stop=stop)
for m in messages
]
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
return LLMResult(generations=generations, llm_output=llm_output)
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
return output
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
prompt_strings = [p.to_string() for p in prompts]
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
)
try:
output = self.generate(prompt_messages, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
prompt_strings = [p.to_string() for p in prompts]
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
)
try:
output = await self.agenerate(prompt_messages, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
@abstractmethod
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Top Level call"""
@abstractmethod
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Top Level call"""
def __call__(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> BaseMessage:
return self._generate(messages, stop=stop).generations[0].message
generation = self.generate(
[messages], stop=stop, callbacks=callbacks
).generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
result = self([HumanMessage(content=message)], stop=stop)
@@ -134,15 +187,21 @@ class BaseChatModel(BaseLanguageModel, ABC):
class SimpleChatModel(BaseChatModel):
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
output_str = self._call(messages, stop=stop)
output_str = self._call(messages, stop=stop, run_manager=run_manager)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@abstractmethod
def _call(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Simpler interface."""

View File

@@ -14,6 +14,10 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
AIMessage,
@@ -242,7 +246,10 @@ class ChatOpenAI(BaseChatModel):
return {"token_usage": overall_token_usage, "model_name": self.model_name}
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
if self.streaming:
@@ -255,10 +262,8 @@ class ChatOpenAI(BaseChatModel):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{"content": inner_completion, "role": role}
)
@@ -287,7 +292,10 @@ class ChatOpenAI(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output)
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
if self.streaming:
@@ -300,16 +308,8 @@ class ChatOpenAI(BaseChatModel):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{"content": inner_completion, "role": role}
)
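For `ChatOpenAI`, streamed tokens now reach handlers through `run_manager.on_llm_new_token` instead of a model-level `callback_manager`, and per-call handlers arrive via the `callbacks` argument on `__call__`/`generate`. A usage sketch (requires an `OPENAI_API_KEY`; `ChatAnthropic` follows the same pattern per the hunks above):

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(streaming=True, temperature=0)
# Tokens are printed as they stream in; the handler is scoped to this call.
message = chat(
    [HumanMessage(content="Write a haiku about tracing.")],
    callbacks=[StreamingStdOutCallbackHandler()],
)
print(message.content)
```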

View File

@@ -2,6 +2,10 @@
import datetime
from typing import List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult
@@ -33,13 +37,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
return_pl_id: Optional[bool] = False
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Call ChatOpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
request_start_time = datetime.datetime.now().timestamp()
generated_responses = super()._generate(messages, stop)
generated_responses = super()._generate(messages, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
message_dicts, params = super()._create_message_dicts(messages, stop)
for i, generation in enumerate(generated_responses.generations):
@@ -67,13 +74,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
return generated_responses
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Call ChatOpenAI agenerate and then call PromptLayer to log."""
from promptlayer.utils import get_api_key, promptlayer_api_request_async
request_start_time = datetime.datetime.now().timestamp()
generated_responses = await super()._agenerate(messages, stop)
generated_responses = await super()._agenerate(messages, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
message_dicts, params = super()._create_message_dicts(messages, stop)
for i, generation in enumerate(generated_responses.generations):

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.experimental.autonomous_agents.baby_agi.task_creation import (
TaskCreationChain,
@@ -13,7 +14,6 @@ from langchain.experimental.autonomous_agents.baby_agi.task_execution import (
from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import (
TaskPrioritizationChain,
)
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore

View File

@@ -1,5 +1,5 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
class TaskCreationChain(LLMChain):

View File

@@ -1,5 +1,5 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
class TaskExecutionChain(LLMChain):

View File

@@ -1,5 +1,5 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
class TaskPrioritizationChain(LLMChain):

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@@ -106,7 +107,12 @@ class AI21(LLM):
"""Return type of llm."""
return "ai21"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to AI21's complete endpoint.
Args:

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -200,7 +201,12 @@ class AlephAlpha(LLM):
"""Return type of llm."""
return "alpeh_alpha"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to Aleph Alpha's completion endpoint.
Args:

View File

@@ -4,6 +4,10 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@@ -142,7 +146,12 @@ class Anthropic(LLM, _AnthropicCommon):
# As a last resort, wrap the prompt ourselves to emulate instruct-style.
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
r"""Call out to Anthropic's completion endpoint.
Args:
@@ -171,9 +180,8 @@ class Anthropic(LLM, _AnthropicCommon):
for data in stream_resp:
delta = data["completion"][len(current_completion) :]
current_completion = data["completion"]
self.callback_manager.on_llm_new_token(
delta, verbose=self.verbose, **data
)
if run_manager:
run_manager.on_llm_new_token(delta, **data)
return current_completion
response = self.client.completion(
prompt=self._wrap_prompt(prompt),
@@ -182,7 +190,12 @@ class Anthropic(LLM, _AnthropicCommon):
)
return response["completion"]
async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> str:
"""Call out to Anthropic's completion endpoint asynchronously."""
stop = self._get_anthropic_stop(stop)
if self.streaming:
@@ -195,14 +208,8 @@ class Anthropic(LLM, _AnthropicCommon):
async for data in stream_resp:
delta = data["completion"][len(current_completion) :]
current_completion = data["completion"]
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
delta, verbose=self.verbose, **data
)
else:
self.callback_manager.on_llm_new_token(
delta, verbose=self.verbose, **data
)
if run_manager:
await run_manager.on_llm_new_token(delta, **data)
return current_completion
response = await self.client.acompletion(
prompt=self._wrap_prompt(prompt),

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -80,7 +81,12 @@ class Banana(LLM):
"""Return type of llm."""
return "banana"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call to Banana endpoint."""
try:
import banana_dev as banana

View File

@@ -1,16 +1,25 @@
"""Base interface for large language models to expose."""
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import yaml
from pydantic import Extra, Field, validator
from pydantic import Extra, Field, root_validator, validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import Generation, LLMResult, PromptValue
def _get_verbosity() -> bool:
@@ -59,7 +68,8 @@ class BaseLLM(BaseLanguageModel, ABC):
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
callbacks: Callbacks = None
callback_manager: Optional[BaseCallbackManager] = None
class Config:
"""Configuration for this pydantic object."""
@@ -67,15 +77,16 @@ class BaseLLM(BaseLanguageModel, ABC):
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
@@ -90,30 +101,45 @@ class BaseLLM(BaseLanguageModel, ABC):
@abstractmethod
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Run the LLM on the given prompts."""
@abstractmethod
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Run the LLM on the given prompts."""
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
return self.generate(prompt_strings, stop=stop)
return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
return await self.agenerate(prompt_strings, stop=stop)
return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
@@ -124,21 +150,31 @@ class BaseLLM(BaseLanguageModel, ABC):
f" argument of type {type(prompts)}."
)
disregard_cache = self.cache is not None and not self.cache
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts
)
try:
output = self._generate(prompts, stop=stop)
output = (
self._generate(prompts, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
run_manager.on_llm_error(e)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
run_manager.on_llm_end(output)
return output
params = self.dict()
params["stop"] = stop
@@ -149,15 +185,19 @@ class BaseLLM(BaseLanguageModel, ABC):
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
run_manager = self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts
)
try:
new_results = self._generate(missing_prompts, stop=stop)
new_results = (
self._generate(missing_prompts, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._generate(missing_prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
run_manager.on_llm_error(e)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
run_manager.on_llm_end(new_results)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
@@ -167,36 +207,38 @@ class BaseLLM(BaseLanguageModel, ABC):
return LLMResult(generations=generations, llm_output=llm_output)
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
disregard_cache = self.cache is not None and not self.cache
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts
)
try:
output = await self._agenerate(prompts, stop=stop)
output = (
await self._agenerate(prompts, stop=stop, run_manager=run_manager)
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
await run_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
await run_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self.dict()
params["stop"] = stop
@@ -207,32 +249,22 @@ class BaseLLM(BaseLanguageModel, ABC):
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
verbose=self.verbose,
)
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
)
try:
new_results = await self._agenerate(missing_prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(
new_results, verbose=self.verbose
new_results = (
await self._agenerate(
missing_prompts, stop=stop, run_manager=run_manager
)
if new_arg_supported
else await self._agenerate(missing_prompts, stop=stop)
)
else:
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
await run_manager.on_llm_end(new_results)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
@@ -241,9 +273,15 @@ class BaseLLM(BaseLanguageModel, ABC):
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def __call__(
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
return self.generate([prompt], stop=stop).generations[0][0].text
return (
self.generate([prompt], stop=stop, callbacks=callbacks)
.generations[0][0]
.text
)
@property
def _identifying_params(self) -> Mapping[str, Any]:
@@ -307,30 +345,56 @@ class LLM(BaseLLM):
"""
@abstractmethod
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Run the LLM on the given prompt and input."""
async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> str:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError("Async generation not implemented for this LLM.")
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# TODO: add caching here.
generations = []
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
for prompt in prompts:
text = self._call(prompt, stop=stop)
text = (
self._call(prompt, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._call(prompt, stop=stop)
)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
for prompt in prompts:
text = await self._acall(prompt, stop=stop)
text = (
await self._acall(prompt, stop=stop, run_manager=run_manager)
if new_arg_supported
else await self._acall(prompt, stop=stop)
)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
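With the signature change above, a custom LLM opts into callbacks simply by accepting the optional `run_manager`; the `inspect.signature` check keeps older two-argument implementations working unchanged. A minimal sketch against the new interface (`EchoLLM` is a hypothetical example, not from this changeset):

```python
from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Toy LLM that streams the prompt back word by word."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        for word in prompt.split():
            if run_manager:
                # Handlers attached via `callbacks` receive each token here.
                run_manager.on_llm_new_token(word + " ")
        return prompt


llm = EchoLLM()
llm("streamed word by word", callbacks=[StreamingStdOutCallbackHandler()])
```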

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -81,7 +82,12 @@ class CerebriumAI(LLM):
"""Return type of llm."""
return "cerebriumai"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call to CerebriumAI endpoint."""
try:
from cerebrium import model_api_request

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -100,7 +101,12 @@ class Cohere(LLM):
"""Return type of llm."""
return "cohere"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to Cohere's generate endpoint.
Args:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -60,7 +61,12 @@ class DeepInfra(LLM):
"""Return type of llm."""
return "deepinfra"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to DeepInfra's inference API endpoint.
Args:

View File

@@ -1,6 +1,7 @@
"""Fake LLM wrapper for testing purposes."""
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@@ -15,7 +16,12 @@ class FakeListLLM(LLM):
"""Return type of llm."""
return "fake-list"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
response = self.responses[self.i]
self.i += 1

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -81,7 +82,12 @@ class ForefrontAI(LLM):
"""Return type of llm."""
return "forefrontai"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to ForefrontAI's complete endpoint.
Args:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@@ -130,7 +131,12 @@ class GooseAI(LLM):
"""Return type of llm."""
return "gooseai"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call the GooseAI API."""
params = self._default_params
if stop is not None:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
@@ -159,7 +160,12 @@ class GPT4All(LLM):
"""Return the type of llm."""
return "gpt4all"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
r"""Call out to GPT4All's generate method.
Args:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -88,7 +89,12 @@ class HuggingFaceEndpoint(LLM):
"""Return type of llm."""
return "huggingface_endpoint"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -84,7 +85,12 @@ class HuggingFaceHub(LLM):
"""Return type of llm."""
return "huggingface_hub"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:

View File

@@ -5,6 +5,7 @@ from typing import Any, List, Mapping, Optional
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
@@ -146,7 +147,12 @@ class HuggingFacePipeline(LLM):
def _llm_type(self) -> str:
return "huggingface_pipeline"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
response = self.pipeline(prompt)
if self.pipeline.task == "text-generation":
# Text generation return includes the starter text.

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
logger = logging.getLogger(__name__)
@@ -154,7 +155,12 @@ class LlamaCpp(LLM):
"""Return type of llm."""
return "llama.cpp"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call the Llama model and return the output.
Args:

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@@ -42,7 +43,12 @@ class ManifestWrapper(LLM):
"""Return type of llm."""
return "manifest"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to LLM through Manifest."""
if stop is not None and len(stop) != 1:
raise NotImplementedError(

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
@@ -69,7 +70,12 @@ class Modal(LLM):
"""Return type of llm."""
return "modal"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call to Modal endpoint."""
params = self.model_kwargs or {}
response = requests.post(

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@@ -111,7 +112,12 @@ class NLPCloud(LLM):
"""Return type of llm."""
return "nlpcloud"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to NLPCloud's create endpoint.
Args:

View File

@@ -29,6 +29,10 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
@@ -253,7 +257,10 @@ class BaseOpenAI(BaseLLM):
return {**normal_params, **self.model_kwargs}
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Call out to OpenAI's endpoint with k unique prompts.
@@ -286,11 +293,12 @@ class BaseOpenAI(BaseLLM):
for stream_resp in completion_with_retry(
self, prompt=_prompts, **params
):
self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
)
if run_manager:
run_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
else:
@@ -302,7 +310,10 @@ class BaseOpenAI(BaseLLM):
return self.create_llm_result(choices, prompts, token_usage)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params
@@ -321,14 +332,8 @@ class BaseOpenAI(BaseLLM):
async for stream_resp in await acompletion_with_retry(
self, prompt=_prompts, **params
):
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
)
else:
self.callback_manager.on_llm_new_token(
if run_manager:
await run_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
@@ -704,7 +709,10 @@ class OpenAIChat(BaseLLM):
return messages, params
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
if self.streaming:
@@ -713,10 +721,10 @@ class OpenAIChat(BaseLLM):
for stream_resp in completion_with_retry(self, messages=messages, **params):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
if run_manager:
run_manager.on_llm_new_token(
token,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
@@ -734,7 +742,10 @@ class OpenAIChat(BaseLLM):
)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
if self.streaming:
@@ -745,15 +756,9 @@ class OpenAIChat(BaseLLM):
):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
if run_manager:
await run_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
return LLMResult(
generations=[[Generation(text=response)]],
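Because `__call__` and `generate` now take a `callbacks` argument, streaming handlers no longer have to be baked into the model at construction time; the run manager created inside `generate` forwards each streamed token to them. A sketch, assuming a configured OpenAI API key:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import OpenAI

llm = OpenAI(streaming=True, temperature=0)
# Per-call handlers: merged with any model-level callbacks by configure().
llm("Tell me a joke.", callbacks=[StreamingStdOutCallbackHandler()])
```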

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -130,7 +131,12 @@ class Petals(LLM):
"""Return type of llm."""
return "petals"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call the Petals API."""
params = self._default_params
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]

View File

@@ -2,6 +2,10 @@
import datetime
from typing import List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms import OpenAI, OpenAIChat
from langchain.schema import LLMResult
@@ -33,13 +37,16 @@ class PromptLayerOpenAI(OpenAI):
return_pl_id: Optional[bool] = False
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
request_start_time = datetime.datetime.now().timestamp()
generated_responses = super()._generate(prompts, stop)
generated_responses = super()._generate(prompts, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
for i in range(len(prompts)):
prompt = prompts[i]
@@ -69,12 +76,15 @@ class PromptLayerOpenAI(OpenAI):
return generated_responses
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async
request_start_time = datetime.datetime.now().timestamp()
generated_responses = await super()._agenerate(prompts, stop)
generated_responses = await super()._agenerate(prompts, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
for i in range(len(prompts)):
prompt = prompts[i]
@@ -131,13 +141,16 @@ class PromptLayerOpenAIChat(OpenAIChat):
return_pl_id: Optional[bool] = False
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
request_start_time = datetime.datetime.now().timestamp()
generated_responses = super()._generate(prompts, stop)
generated_responses = super()._generate(prompts, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
for i in range(len(prompts)):
prompt = prompts[i]
@@ -167,12 +180,15 @@ class PromptLayerOpenAIChat(OpenAIChat):
return generated_responses
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async
request_start_time = datetime.datetime.now().timestamp()
generated_responses = await super()._agenerate(prompts, stop)
generated_responses = await super()._agenerate(prompts, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
for i in range(len(prompts)):
prompt = prompts[i]
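The PromptLayer wrappers only need to thread `run_manager` through to the parent class; the same forwarding pattern works for any subclass that wraps `_generate`. A sketch with a hypothetical `TimedOpenAI` (not part of this changeset):

```python
import time
from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import OpenAI
from langchain.schema import LLMResult


class TimedOpenAI(OpenAI):
    """Hypothetical subclass that times each generation."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        start = time.time()
        # Forwarding run_manager keeps streaming callbacks wired up.
        result = super()._generate(prompts, stop, run_manager)
        print(f"_generate took {time.time() - start:.2f}s")
        return result
```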

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@@ -78,7 +79,12 @@ class Replicate(LLM):
"""Return type of model."""
return "replicate"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call to replicate endpoint."""
try:
import replicate as replicate_python

View File

@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
@@ -204,7 +205,12 @@ class RWKV(LLM, BaseModel):
return decoded
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
r"""RWKV generation
Args:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Union
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
@@ -194,7 +195,12 @@ class SagemakerEndpoint(LLM):
"""Return type of llm."""
return "sagemaker_endpoint"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to Sagemaker inference endpoint.
Args:

View File

@@ -6,6 +6,7 @@ from typing import Any, Callable, List, Mapping, Optional
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
@@ -208,5 +209,10 @@ class SelfHostedPipeline(LLM):
def _llm_type(self) -> str:
return "self_hosted_llm"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)

View File

@@ -5,6 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.self_hosted import SelfHostedPipeline
from langchain.llms.utils import enforce_stop_tokens
@@ -198,5 +199,10 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
def _llm_type(self) -> str:
return "selfhosted_huggingface_pipeline"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -80,7 +81,12 @@ class StochasticAI(LLM):
"""Return type of llm."""
return "stochasticai"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to StochasticAI's complete endpoint.
Args:

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -117,7 +118,12 @@ class Writer(LLM):
"""Return type of llm."""
return "writer"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to Writer's complete endpoint.
Args:

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
@@ -13,7 +14,7 @@ from langchain.memory.prompt import (
)
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
from langchain.schema import BaseMessage, get_buffer_string
logger = logging.getLogger(__name__)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -13,7 +14,6 @@ from langchain.memory.prompt import (
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseMessage,
SystemMessage,
get_buffer_string,

View File

@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Type
from pydantic import BaseModel, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseMessage,
SystemMessage,
get_buffer_string,

View File

@@ -1,7 +1,8 @@
from typing import Any, Dict, List
from langchain.base_language import BaseLanguageModel
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
from langchain.schema import BaseMessage, get_buffer_string
class ConversationTokenBufferMemory(BaseChatMemory):

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
from typing import TypeVar
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
from langchain.schema import BaseOutputParser, OutputParserException
T = TypeVar("T")

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
from typing import TypeVar
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseOutputParser,
OutputParserException,
PromptValue,

View File

@@ -171,45 +171,6 @@ class PromptValue(BaseModel, ABC):
"""Return prompt as messages."""
class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the message."""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
class BaseMemory(BaseModel, ABC):
"""Base interface for memory in chains."""

View File

@@ -1,13 +1,21 @@
"""Base implementation for tools or skills."""
import inspect
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, validate_arguments, validator
from pydantic import BaseModel, Extra, Field, root_validator, validate_arguments
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForToolRun,
CallbackManager,
CallbackManagerForToolRun,
Callbacks,
)
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
@@ -28,7 +36,8 @@ class BaseTool(ABC, BaseModel):
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
callbacks: Callbacks = None
callback_manager: Optional[BaseCallbackManager] = None
class Config:
"""Configuration for this pydantic object."""
@@ -60,15 +69,16 @@ class BaseTool(ABC, BaseModel):
if input_args is not None:
input_args.validate(tool_input)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str:
@@ -84,6 +94,7 @@ class BaseTool(ABC, BaseModel):
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
**kwargs: Any,
) -> str:
"""Run the tool."""
@@ -92,22 +103,22 @@ class BaseTool(ABC, BaseModel):
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, verbose=verbose_
)
run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
observation = self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
run_manager.on_tool_error(e)
raise e
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
return observation
async def arun(
@@ -116,6 +127,7 @@ class BaseTool(ABC, BaseModel):
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
@@ -124,42 +136,27 @@ class BaseTool(ABC, BaseModel):
verbose_ = verbose
else:
verbose_ = self.verbose
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, verbose=verbose_
)
run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
observation = await self._arun(*args, run_manager=run_manager, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
else:
self.callback_manager.on_tool_error(e, verbose=verbose_)
await run_manager.on_tool_error(e)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
else:
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
await run_manager.on_tool_end(
observation, color=color, name=self.name, **kwargs
)
return observation
def __call__(self, tool_input: str) -> str:
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
"""Make tool callable."""
return self.run(tool_input)
return self.run(tool_input, callbacks=callbacks)
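Tools get the same treatment as LLMs: `run`, `arun`, and `__call__` all accept per-call callbacks, and `configure()` merges them with any handlers set on the tool itself. A sketch, assuming the bundled tools accept the new `run_manager` argument and a SerpAPI/OpenAI key is configured:

```python
from langchain.agents import load_tools
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
calculator = load_tools(["llm-math"], llm=llm)[0]
# Per-call handlers observe on_tool_start / on_tool_end for this run only.
calculator.run("What is 2 raised to the 10th power?", callbacks=[StdOutCallbackHandler()])
```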

View File

@@ -76,7 +76,7 @@ class SerpAPIWrapper(BaseModel):
)
return values
async def arun(self, query: str) -> str:
async def arun(self, query: str, **kwargs: Any) -> str:
"""Use aiohttp to run query through SerpAPI and parse result."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
@@ -99,7 +99,7 @@ class SerpAPIWrapper(BaseModel):
return self._process_response(res)
def run(self, query: str) -> str:
def run(self, query: str, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result."""
return self._process_response(self.results(query))

View File

@@ -0,0 +1,80 @@
"""Integration tests for the langchain tracer module."""
import asyncio
import os
import time
import pytest
from aiohttp import ClientSession
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.llms import OpenAI
questions = [
"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?",
"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?",
"Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?",
"Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?",
"Who is Beyonce's husband? What is his age raised to the 0.19 power?",
]
def test_tracing_sequential() -> None:
os.environ["LANGCHAIN_TRACING"] = "true"
for q in questions[:3]:
llm = OpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run(q)
@pytest.mark.asyncio
async def test_tracing_concurrent() -> None:
os.environ["LANGCHAIN_TRACING"] = "true"
aiosession = ClientSession()
llm = OpenAI(temperature=0)
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
tasks = [agent.arun(q) for q in questions[:3]]
await asyncio.gather(*tasks)
await aiosession.close()
def test_tracing_context_manager() -> None:
llm = OpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
with tracing_enabled() as session:
assert session
agent.run(questions[0]) # this should be traced
agent.run(questions[0]) # this should not be traced
@pytest.mark.asyncio
async def test_tracing_context_manager_async() -> None:
llm = OpenAI(temperature=0)
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
# start a background task
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
with tracing_enabled() as session:
assert session
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
await asyncio.gather(*tasks)
await task

View File

@@ -0,0 +1,37 @@
"""Integration tests for the langchain tracer module."""
import asyncio
import pytest
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
@pytest.mark.asyncio
async def test_openai_callback() -> None:
llm = OpenAI(temperature=0)
with get_openai_callback() as cb:
llm("What is the square root of 4?")
total_tokens = cb.total_tokens
assert total_tokens > 0
with get_openai_callback() as cb:
llm("What is the square root of 4?")
llm("What is the square root of 4?")
assert cb.total_tokens == total_tokens * 2
with get_openai_callback() as cb:
await asyncio.gather(
*[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)]
)
assert cb.total_tokens == total_tokens * 3
task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"]))
with get_openai_callback() as cb:
await llm.agenerate(["What is the square root of 4?"])
await task
assert cb.total_tokens == total_tokens

View File

@@ -3,7 +3,7 @@ from typing import List
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.schema import (
AIMessage,

View File

@@ -3,7 +3,7 @@
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import (
BaseMessage,

View File

@@ -2,7 +2,7 @@
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.schema import (
BaseMessage,

View File

@@ -3,7 +3,7 @@ from typing import Generator
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.llms.anthropic import Anthropic
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

View File

@@ -5,7 +5,7 @@ from typing import Generator
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult

View File

@@ -4,7 +4,7 @@ from typing import Any, List, Mapping, Optional
from langchain.agents import AgentExecutor, AgentType, initialize_agent
from langchain.agents.tools import Tool
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.llms.base import LLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

View File

@@ -1,10 +1,9 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List, Union
from typing import Any
from pydantic import BaseModel
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseFakeCallbackHandler(BaseModel):
@@ -17,12 +16,72 @@ class BaseFakeCallbackHandler(BaseModel):
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
always_verbose_: bool = False
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
llm_streams: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_actions: int = 0
agent_ends: int = 0
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
"""Base fake callback handler mixin for testing."""
def on_llm_start_common(self) -> None:
self.llm_starts += 1
self.starts += 1
def on_llm_end_common(self) -> None:
self.llm_ends += 1
self.ends += 1
def on_llm_error_common(self) -> None:
self.errors += 1
def on_llm_new_token_common(self) -> None:
self.llm_streams += 1
def on_chain_start_common(self) -> None:
self.chain_starts += 1
self.starts += 1
def on_chain_end_common(self) -> None:
self.chain_ends += 1
self.ends += 1
def on_chain_error_common(self) -> None:
self.errors += 1
def on_tool_start_common(self) -> None:
self.tool_starts += 1
self.starts += 1
def on_tool_end_common(self) -> None:
self.tool_ends += 1
self.ends += 1
def on_tool_error_common(self) -> None:
self.errors += 1
def on_agent_action_common(self) -> None:
self.agent_actions += 1
self.starts += 1
def on_agent_finish_common(self) -> None:
self.agent_ends += 1
self.ends += 1
def on_text_common(self) -> None:
self.text += 1
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Fake callback handler for testing."""
@property
def ignore_llm(self) -> bool:
@@ -39,164 +98,209 @@ class BaseFakeCallbackHandler(BaseModel):
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
llm_streams: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_ends: int = 0
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
"""Fake callback handler for testing."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.llm_starts += 1
self.starts += 1
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_start_common()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1
def on_llm_new_token(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_new_token_common()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
self.ends += 1
def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_end_common()
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.errors += 1
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_error_common()
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.chain_starts += 1
self.starts += 1
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_start_common()
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.chain_ends += 1
self.ends += 1
def on_chain_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_end_common()
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
self.errors += 1
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_error_common()
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
self.starts += 1
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_start_common()
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.tool_ends += 1
self.ends += 1
def on_tool_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_end_common()
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
self.errors += 1
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_error_common()
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.text += 1
def on_agent_action(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_agent_action_common()
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
def on_agent_finish(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_agent_finish_common()
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1
def on_text(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_text_common()
def __deepcopy__(self, memo):
return self
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Fake async callback handler for testing."""
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
self.llm_starts += 1
self.starts += 1
self.on_llm_start_common()
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1
async def on_llm_new_token(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_llm_new_token_common()
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
self.ends += 1
async def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_llm_end_common()
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
self.errors += 1
self.on_llm_error_common()
async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
self.chain_starts += 1
self.starts += 1
self.on_chain_start_common()
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.chain_ends += 1
self.ends += 1
async def on_chain_end(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_chain_end_common()
async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
self.errors += 1
self.on_chain_error_common()
async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
self.starts += 1
self.on_tool_start_common()
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.tool_ends += 1
self.ends += 1
async def on_tool_end(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_tool_end_common()
async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
self.errors += 1
self.on_tool_error_common()
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when text is received."""
self.text += 1
async def on_agent_action(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_agent_action_common()
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
async def on_agent_finish(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_agent_finish_common()
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1
async def on_text(
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_text_common()
def __deepcopy__(self, memo):
return self
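# The fake handlers above funnel every hook through a shared "*_common" method
# on a counter mixin, so one set of counters covers both the sync and async
# fakes. A minimal, self-contained sketch of that pattern (CounterMixin and the
# two Fake*Handler names below are illustrative, not the exact test fixtures):
from typing import Any


class CounterMixin:
    """Owns the counters the tests assert against."""

    def __init__(self) -> None:
        self.starts = 0
        self.ends = 0

    def on_start_common(self) -> None:
        self.starts += 1

    def on_end_common(self) -> None:
        self.ends += 1


class FakeSyncHandler(CounterMixin):
    def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        self.on_start_common()

    def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        self.on_end_common()


class FakeAsyncHandler(CounterMixin):
    async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        self.on_start_common()

    async def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        self.on_end_common()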

View File

@@ -3,13 +3,9 @@ from typing import Tuple
import pytest
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentFinish, LLMResult
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
@@ -18,19 +14,26 @@ from tests.unit_tests.callbacks.fake_callback_handler import (
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
manager: CallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
manager.on_llm_end(LLMResult(generations=[]))
manager.on_llm_error(Exception())
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, "")
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
run_manager = manager.on_llm_start({}, [])
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_manager.on_llm_new_token("foo")
run_manager.on_text("foo")
run_manager = manager.on_chain_start({"name": "foo"}, {})
run_manager.on_chain_end({})
run_manager.on_chain_error(Exception())
run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
run_manager.on_agent_finish(AgentFinish(log="", return_values={}))
run_manager.on_text("foo")
run_manager = manager.on_tool_start({}, "")
run_manager.on_tool_end("")
run_manager.on_tool_error(Exception())
run_manager.on_text("foo")
_check_num_calls(handlers)
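# Usage sketch of the two-level dispatch this helper exercises: *_start events
# fire on the CallbackManager and hand back a run-scoped manager, and every
# follow-up event for that run goes through the returned object. Imports match
# this branch; the prompt text is illustrative.
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.schema import LLMResult

manager = CallbackManager([StdOutCallbackHandler()])
run_manager = manager.on_llm_start({}, ["Once upon a time, "])
run_manager.on_llm_new_token("foo")
run_manager.on_llm_end(LLMResult(generations=[]))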
@@ -38,75 +41,60 @@ async def _test_callback_manager_async(
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
await manager.on_llm_start({}, [])
await manager.on_llm_end(LLMResult(generations=[]))
await manager.on_llm_error(Exception())
await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({})
await manager.on_chain_error(Exception())
await manager.on_tool_start({}, "")
await manager.on_tool_end("")
await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
run_manager = await manager.on_llm_start({}, [])
await run_manager.on_llm_end(LLMResult(generations=[]))
await run_manager.on_llm_error(Exception())
await run_manager.on_llm_new_token("foo")
await run_manager.on_text("foo")
run_manager = await manager.on_chain_start({"name": "foo"}, {})
await run_manager.on_chain_end({})
await run_manager.on_chain_error(Exception())
await run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
await run_manager.on_agent_finish(AgentFinish(log="", return_values={}))
await run_manager.on_text("foo")
run_manager = await manager.on_tool_start({}, "")
await run_manager.on_tool_end("")
await run_manager.on_tool_error(Exception())
await run_manager.on_text("foo")
_check_num_calls(handlers)
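# The async manager mirrors the sync flow with every dispatch awaited, and sync
# handlers can still be registered (see test_async_callback_manager_sync_handler
# below). A usage sketch under the same assumptions as the sync example:
import asyncio

from langchain.callbacks.manager import AsyncCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult


async def stream_demo() -> None:
    manager = AsyncCallbackManager([StreamingStdOutCallbackHandler()])
    run_manager = await manager.on_llm_start({}, ["Once upon a time, "])
    await run_manager.on_llm_new_token("foo")
    await run_manager.on_llm_end(LLMResult(generations=[]))


asyncio.run(stream_demo())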
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
for handler in handlers:
if handler.always_verbose:
assert handler.starts == 3
assert handler.ends == 4
assert handler.errors == 3
else:
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0
def _test_callback_manager_pass_in_verbose(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [], verbose=True)
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
manager.on_llm_error(Exception(), verbose=True)
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True)
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
for handler in handlers:
assert handler.starts == 3
assert handler.starts == 4
assert handler.ends == 4
assert handler.errors == 3
assert handler.text == 3
assert handler.llm_starts == 1
assert handler.llm_ends == 1
assert handler.llm_streams == 1
assert handler.chain_starts == 1
assert handler.chain_ends == 1
assert handler.tool_starts == 1
assert handler.tool_ends == 1
def test_callback_manager() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=False)
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
def test_callback_manager_pass_in_verbose() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2])
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
def test_ignore_llm() -> None:
"""Test ignore llm param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
handler1 = FakeCallbackHandler(ignore_llm_=True)
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_llm_start({}, [], verbose=True)
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
manager.on_llm_error(Exception(), verbose=True)
run_manager = manager.on_llm_start({}, [])
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0
@@ -117,12 +105,12 @@ def test_ignore_llm() -> None:
def test_ignore_chain() -> None:
"""Test ignore chain param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
handler1 = FakeCallbackHandler(ignore_chain_=True)
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True)
run_manager = manager.on_chain_start({"name": "foo"}, {})
run_manager.on_chain_end({})
run_manager.on_chain_error(Exception())
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0
@@ -133,39 +121,24 @@ def test_ignore_chain() -> None:
def test_ignore_agent() -> None:
"""Test ignore agent param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
handler1 = FakeCallbackHandler(ignore_agent_=True)
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
run_manager = manager.on_tool_start({}, "")
run_manager.on_tool_end("")
run_manager.on_tool_error(Exception())
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0
assert handler2.starts == 1
assert handler2.ends == 2
assert handler2.ends == 1
assert handler2.errors == 1
def test_shared_callback_manager() -> None:
"""Test the SharedCallbackManager."""
manager1 = SharedCallbackManager()
manager2 = SharedCallbackManager()
assert manager1 is manager2
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)
@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
handler1 = FakeAsyncCallbackHandler()
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)
@@ -174,8 +147,95 @@ async def test_async_callback_manager() -> None:
@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler1 = FakeCallbackHandler()
handler2 = FakeAsyncCallbackHandler()
handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
handler3 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2, handler3])
await _test_callback_manager_async(manager, handler1, handler2, handler3)
def test_callback_manager_inheritance() -> None:
handler1, handler2, handler3, handler4 = (
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
)
callback_manager1 = CallbackManager([handler1, handler2])
assert callback_manager1.handlers == [handler1, handler2]
assert callback_manager1.inheritable_handlers == []
callback_manager2 = CallbackManager([])
assert callback_manager2.handlers == []
assert callback_manager2.inheritable_handlers == []
callback_manager2.set_handlers([handler1, handler2])
assert callback_manager2.handlers == [handler1, handler2]
assert callback_manager2.inheritable_handlers == [handler1, handler2]
callback_manager2.set_handlers([handler3, handler4], inherit=False)
assert callback_manager2.handlers == [handler3, handler4]
assert callback_manager2.inheritable_handlers == []
callback_manager2.add_handler(handler1)
assert callback_manager2.handlers == [handler3, handler4, handler1]
assert callback_manager2.inheritable_handlers == [handler1]
callback_manager2.add_handler(handler2, inherit=False)
assert callback_manager2.handlers == [handler3, handler4, handler1, handler2]
assert callback_manager2.inheritable_handlers == [handler1]
run_manager = callback_manager2.on_chain_start({"name": "foo"}, {})
child_manager = run_manager.get_child()
assert child_manager.handlers == [handler1]
assert child_manager.inheritable_handlers == [handler1]
child_manager = child_manager.on_tool_start({}, "")
assert child_manager.handlers == [handler1]
assert child_manager.inheritable_handlers == [handler1]
child_manager = child_manager.get_child()
assert child_manager.handlers == [handler1]
assert child_manager.inheritable_handlers == [handler1]
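# Condensed version of the propagation rule this test encodes: handlers added
# with inherit=True flow into child managers via get_child(), while handlers
# added with inherit=False stay on the manager they were registered with.
from langchain.callbacks.manager import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

inherited, local = FakeCallbackHandler(), FakeCallbackHandler()
manager = CallbackManager([])
manager.add_handler(inherited)             # inherit defaults to True
manager.add_handler(local, inherit=False)  # local to this manager only

run_manager = manager.on_chain_start({"name": "parent"}, {})
child = run_manager.get_child()
assert child.handlers == [inherited]       # the local handler did not propagate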
def test_callback_manager_configure() -> None:
"""Test callback manager configuration."""
handler1, handler2, handler3, handler4 = (
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
)
inheritable_callbacks = [handler1, handler2]
local_callbacks = [handler3, handler4]
configured_manager = CallbackManager.configure(
inheritable_callbacks=inheritable_callbacks,
local_callbacks=local_callbacks,
verbose=True,
)
assert len(configured_manager.handlers) == 5
assert len(configured_manager.inheritable_handlers) == 2
assert configured_manager.inheritable_handlers == inheritable_callbacks
assert configured_manager.handlers[:4] == inheritable_callbacks + local_callbacks
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
assert isinstance(configured_manager, CallbackManager)
async_local_callbacks = AsyncCallbackManager([handler3, handler4])
async_configured_manager = AsyncCallbackManager.configure(
inheritable_callbacks=inheritable_callbacks,
local_callbacks=async_local_callbacks,
verbose=False,
)
assert len(async_configured_manager.handlers) == 4
assert len(async_configured_manager.inheritable_handlers) == 2
assert async_configured_manager.inheritable_handlers == inheritable_callbacks
assert async_configured_manager.handlers == inheritable_callbacks + [
handler3,
handler4,
]
assert isinstance(async_configured_manager, AsyncCallbackManager)
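# Usage sketch of configure(), matching the assertions above: inheritable
# callbacks come first, local callbacks follow, and verbose=True appends a
# StdOutCallbackHandler to the resulting manager.
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

inheritable = [FakeCallbackHandler(), FakeCallbackHandler()]
local = [FakeCallbackHandler()]

manager = CallbackManager.configure(
    inheritable_callbacks=inheritable,
    local_callbacks=local,
    verbose=True,
)
assert manager.inheritable_handlers == inheritable
assert manager.handlers[:3] == inheritable + local
assert isinstance(manager.handlers[3], StdOutCallbackHandler)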

View File

@@ -1,9 +1,9 @@
"""Test Tracer classes."""
from __future__ import annotations
import threading
from datetime import datetime
from typing import List, Optional, Union
from uuid import uuid4
import pytest
from freezegun import freeze_time
@@ -12,9 +12,7 @@ from langchain.callbacks.tracers.base import (
BaseTracer,
ChainRun,
LLMRun,
SharedTracer,
ToolRun,
Tracer,
TracerException,
TracerSession,
)
@@ -27,37 +25,45 @@ TEST_SESSION_ID = 2023
@freeze_time("2023-01-01")
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
return ChainRun(
id=None,
uuid="chain_uuid",
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=4,
serialized={},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
child_runs=[
child_chain_runs=[],
child_tool_runs=[
ToolRun(
id=None,
uuid="tool_uuid",
parent_uuid="chain_uuid",
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
child_execution_order=3,
serialized={},
tool_input="test",
output="test",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
child_runs=[
child_chain_runs=[],
child_tool_runs=[],
child_llm_runs=[
LLMRun(
id=None,
uuid="llm_uuid1",
parent_uuid="tool_uuid",
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
child_execution_order=3,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
@@ -65,13 +71,17 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
)
],
),
],
child_llm_runs=[
LLMRun(
id=None,
uuid="llm_uuid2",
parent_uuid="chain_uuid",
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
child_execution_order=4,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
@@ -83,27 +93,25 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_end("test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_chain_end(outputs={})
chain_uuid = "chain_uuid"
tool_uuid = "tool_uuid"
llm_uuid1 = "llm_uuid1"
llm_uuid2 = "llm_uuid2"
def _add_child_run(
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
parent_run.child_runs.append(child_run)
def _generate_id() -> Optional[Union[int, str]]:
"""Generate an id for a run."""
return None
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
tracer.on_tool_start(
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
)
tracer.on_llm_start(
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
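# The rewritten helper above threads explicit run_id / parent_run_id values
# through every event, which is what lets a tracer rebuild the run tree without
# a shared stack or singleton. A minimal, self-contained sketch of that
# reconstruction idea (simplified; not the library's implementation):
from typing import Dict, List, Optional


class RunNode:
    def __init__(self, run_id: str, parent_run_id: Optional[str]) -> None:
        self.run_id = run_id
        self.parent_run_id = parent_run_id
        self.children: List["RunNode"] = []


class RunTreeBuilder:
    """Rebuilds a run tree purely from (run_id, parent_run_id) pairs."""

    def __init__(self) -> None:
        self._runs: Dict[str, RunNode] = {}
        self.roots: List[RunNode] = []

    def on_start(self, run_id: str, parent_run_id: Optional[str] = None) -> None:
        node = RunNode(run_id, parent_run_id)
        self._runs[run_id] = node
        if parent_run_id is None:
            self.roots.append(node)
        else:
            self._runs[parent_run_id].children.append(node)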
def load_session(session_name: str) -> TracerSession:
@@ -121,7 +129,7 @@ def load_default_session() -> TracerSession:
return TracerSession(id=1, name="default", start_time=datetime.utcnow())
class FakeTracer(Tracer):
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
def __init__(self) -> None:
@@ -133,58 +141,6 @@ class FakeTracer(Tracer):
"""Persist a run."""
self.runs.append(run)
def _add_child_run(
self,
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
_add_child_run(parent_run, child_run)
def _generate_id(self) -> Optional[Union[int, str]]:
"""Generate an id for a run."""
return _generate_id()
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
"""Persist a tracing session."""
return _persist_session(session)
def load_session(self, session_name: str) -> TracerSession:
"""Load a tracing session."""
return load_session(session_name)
def load_default_session(self) -> TracerSession:
"""Load a tracing session."""
return load_default_session()
class FakeSharedTracer(SharedTracer):
"""Fake shared tracer that records LangChain execution."""
runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
with self._lock:
self.runs.append(run)
def remove_runs(self) -> None:
"""Remove all runs."""
with self._lock:
self.runs = []
def _add_child_run(
self,
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
_add_child_run(parent_run, child_run)
def _generate_id(self) -> Optional[Union[int, str]]:
"""Generate an id for a run."""
return _generate_id()
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
"""Persist a tracing session."""
return _persist_session(session)
@@ -201,12 +157,15 @@ class FakeSharedTracer(SharedTracer):
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
uuid = str(uuid4())
compare_run = LLMRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
@@ -216,20 +175,11 @@ def test_tracer_llm_run() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_session() -> None:
"""Test tracer on an LLM run without a session."""
tracer = FakeTracer()
with pytest.raises(TracerException):
tracer.on_llm_start(serialized={}, prompts=[])
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""
@@ -237,18 +187,21 @@ def test_tracer_llm_run_errors_no_start() -> None:
tracer.new_session()
with pytest.raises(TracerException):
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=str(uuid4()))
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
"""Test the tracer with multiple runs."""
uuid = str(uuid4())
compare_run = LLMRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
prompts=[],
response=LLMResult(generations=[[]]),
@@ -260,8 +213,8 @@ def test_tracer_multiple_llm_runs() -> None:
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@@ -269,12 +222,15 @@ def test_tracer_multiple_llm_runs() -> None:
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
uuid = str(uuid4())
compare_run = ChainRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
inputs={},
outputs={},
@@ -284,20 +240,23 @@ def test_tracer_chain_run() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_chain_end(outputs={})
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
tracer.on_chain_end(outputs={}, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
uuid = str(uuid4())
compare_run = ToolRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
tool_input="test",
output="test",
@@ -308,8 +267,8 @@ def test_tracer_tool_run() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_end("test")
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
tracer.on_tool_end("test", run_id=uuid)
assert tracer.runs == [compare_run]
@@ -318,21 +277,24 @@ def test_tracer_nested_run() -> None:
"""Test tracer on a nested run."""
tracer = FakeTracer()
tracer.new_session()
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()]
for _ in range(10):
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()] * 10
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
"""Test tracer on an LLM run with an error."""
exception = Exception("test")
uuid = str(uuid4())
compare_run = LLMRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
prompts=[],
response=None,
@@ -342,8 +304,8 @@ def test_tracer_llm_run_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@@ -351,13 +313,16 @@ def test_tracer_llm_run_on_error() -> None:
def test_tracer_chain_run_on_error() -> None:
"""Test tracer on a Chain run with an error."""
exception = Exception("test")
uuid = str(uuid4())
compare_run = ChainRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
inputs={},
outputs=None,
@@ -367,8 +332,8 @@ def test_tracer_chain_run_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_chain_error(exception)
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@@ -376,13 +341,16 @@ def test_tracer_chain_run_on_error() -> None:
def test_tracer_tool_run_on_error() -> None:
"""Test tracer on a Tool run with an error."""
exception = Exception("test")
uuid = str(uuid4())
compare_run = ToolRun(
id=None,
uuid=uuid,
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
tool_input="test",
output=None,
@@ -393,8 +361,8 @@ def test_tracer_tool_run_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_error(exception)
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@@ -405,37 +373,53 @@ def test_tracer_nested_runs_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
chain_uuid = "chain_uuid"
tool_uuid = "tool_uuid"
llm_uuid1 = "llm_uuid1"
llm_uuid2 = "llm_uuid2"
llm_uuid3 = "llm_uuid3"
for _ in range(3):
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
tracer.on_chain_error(exception)
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
tracer.on_llm_start(
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_tool_start(
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
)
tracer.on_llm_start(
serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid
)
tracer.on_llm_error(exception, run_id=llm_uuid3)
tracer.on_tool_error(exception, run_id=tool_uuid)
tracer.on_chain_error(exception, run_id=chain_uuid)
compare_run = ChainRun(
id=None,
uuid=chain_uuid,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=5,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
inputs={},
outputs=None,
child_runs=[
child_llm_runs=[
LLMRun(
id=None,
uuid=llm_uuid1,
parent_uuid=chain_uuid,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
child_execution_order=2,
serialized={},
session_id=TEST_SESSION_ID,
error=None,
@@ -443,36 +427,45 @@ def test_tracer_nested_runs_on_error() -> None:
response=LLMResult(generations=[[]], llm_output=None),
),
LLMRun(
id=None,
uuid=llm_uuid2,
parent_uuid=chain_uuid,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
child_execution_order=3,
serialized={},
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
],
child_chain_runs=[],
child_tool_runs=[
ToolRun(
id=None,
uuid=tool_uuid,
parent_uuid=chain_uuid,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
child_execution_order=5,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
tool_input="test",
output=None,
action="{}",
child_runs=[
child_llm_runs=[
LLMRun(
id=None,
uuid=llm_uuid3,
parent_uuid=tool_uuid,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=5,
child_execution_order=5,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
@@ -480,43 +473,10 @@ def test_tracer_nested_runs_on_error() -> None:
response=None,
)
],
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
),
],
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
)
assert tracer.runs == [compare_run] * 3
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run() -> None:
"""Test shared tracer on a nested run."""
tracer = FakeSharedTracer()
tracer.new_session()
tracer.remove_runs()
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()]
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run_multithreaded() -> None:
"""Test shared tracer on a nested run."""
tracer = FakeSharedTracer()
tracer.remove_runs()
tracer.new_session()
threads = []
num_threads = 10
for _ in range(num_threads):
thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
assert tracer.runs == [_get_compare_run()] * num_threads

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

View File

@@ -1,5 +1,5 @@
"""Test LLM callbacks."""
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.manager import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM
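# Net effect of the last two files, and of this refactor across the test
# suite: the manager classes move out of langchain.callbacks.base into a
# dedicated module, so call sites migrate as below.

# Before this refactor:
#   from langchain.callbacks.base import CallbackManager
# After:
from langchain.callbacks.manager import CallbackManager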