Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-19 13:23:35 +00:00)

Enable streaming for OpenAI LLM (#986)

* Support a callback `on_llm_new_token` that users can implement when `OpenAI.streaming` is set to `True`

parent f05f025e41
commit caa8e4742e
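A minimal sketch of how the new callback can be consumed, assuming the `OpenAI(streaming=True, callback_manager=...)` pattern shown in the notebook diffs below; the `TokenCollector` handler and the final print line are hypothetical illustrations, not part of this commit:

# Minimal sketch: collect streamed tokens with a custom handler.
# `TokenCollector` is a hypothetical example class; OpenAI, CallbackManager,
# and StreamingStdOutCallbackHandler are the objects used by this commit.
from langchain.llms import OpenAI
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


class TokenCollector(StreamingStdOutCallbackHandler):
    """Keep every streamed token in memory in addition to echoing it to stdout."""

    def __init__(self) -> None:
        self.tokens = []

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called once per token while OpenAI.streaming is True.
        self.tokens.append(token)
        super().on_llm_new_token(token, **kwargs)  # still stream to stdout


handler = TokenCollector()
llm = OpenAI(
    streaming=True,
    callback_manager=CallbackManager([handler]),
    verbose=True,
    temperature=0,
)
llm("Write me a song about sparkling water.")
print("\nCollected", len(handler.tokens), "tokens")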
@@ -14,7 +14,9 @@
 "cell_type": "code",
 "execution_count": 1,
 "id": "70c4e529",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "from langchain.embeddings.openai import OpenAIEmbeddings\n",
@@ -36,7 +38,9 @@
 "cell_type": "code",
 "execution_count": 2,
 "id": "01c46e92",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "from langchain.document_loaders import TextLoader\n",
@@ -56,7 +60,9 @@
 "cell_type": "code",
 "execution_count": 3,
 "id": "433363a5",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "# loaders = [....]\n",
@@ -75,9 +81,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "id": "a8930cf7",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [
 {
 "name": "stdout",
@@ -106,9 +114,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 5,
 "id": "7b4110f3",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore)"
@@ -126,7 +136,9 @@
 "cell_type": "code",
 "execution_count": 6,
 "id": "7fe3e730",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "chat_history = []\n",
@@ -136,9 +148,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 7,
 "id": "bfff9cc8",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [
 {
 "data": {
@@ -146,7 +160,7 @@
 "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
 ]
 },
-"execution_count": 8,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -165,9 +179,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 8,
 "id": "00b4cf00",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "chat_history = [(query, result[\"answer\"])]\n",
@@ -177,9 +193,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 9,
 "id": "f01828d1",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [
 {
 "data": {
@@ -187,7 +205,7 @@
 "' Justice Stephen Breyer'"
 ]
 },
-"execution_count": 11,
+"execution_count": 9,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -196,10 +214,90 @@
 "result['answer']"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "2324cdc6-98bf-4708-b8cd-02a98b1e5b67",
+"metadata": {},
+"source": [
+"## Chat Vector DB with streaming to `stdout`\n",
+"\n",
+"Output from the chain will be streamed to `stdout` token by token in this example."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "2efacec3-2690-4b05-8de3-a32fd2ac3911",
+"metadata": {
+"tags": []
+},
+"outputs": [],
+"source": [
+"from langchain.chains.llm import LLMChain\n",
+"from langchain.callbacks.base import CallbackManager\n",
+"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+"from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
+"from langchain.chains.question_answering import load_qa_chain\n",
+"\n",
+"# Construct a ChatVectorDBChain with a streaming llm for combine docs\n",
+"# and a separate, non-streaming llm for question generation\n",
+"llm = OpenAI(temperature=0)\n",
+"streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"\n",
+"question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
+"doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n",
+"\n",
+"qa = ChatVectorDBChain(vectorstore=vectorstore, combine_docs_chain=doc_chain, question_generator=question_generator)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"id": "fd6d43f4-7428-44a4-81bc-26fe88a98762",
+"metadata": {
+"tags": []
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans."
+]
+}
+],
+"source": [
+"chat_history = []\n",
+"query = \"What did the president say about Ketanji Brown Jackson\"\n",
+"result = qa({\"question\": query, \"chat_history\": chat_history})"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"id": "5ab38978-f3e8-4fa7-808c-c79dec48379a",
+"metadata": {
+"tags": []
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+" Justice Stephen Breyer"
+]
+}
+],
+"source": [
+"chat_history = [(query, result[\"answer\"])]\n",
+"query = \"Did he mention who she suceeded\"\n",
+"result = qa({\"question\": query, \"chat_history\": chat_history})"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "d0f869c6",
+"id": "a7ea93ff-1899-4171-9c24-85df20ae1a3d",
 "metadata": {},
 "outputs": [],
 "source": []
@@ -221,7 +319,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.10.9"
 }
 },
 "nbformat": 4,

@@ -18,7 +18,9 @@
 "cell_type": "code",
 "execution_count": 1,
 "id": "df924055",
-"metadata": {},
+"metadata": {
+ "tags": []
+},
 "outputs": [],
 "source": [
 "from langchain.llms import OpenAI"
@@ -207,14 +209,6 @@
 "source": [
 "llm.get_num_tokens(\"what a joke\")"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "b004ffdd",
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {

@@ -8,6 +8,7 @@ They are split into two categories:
 1. `Generic Functionality <./generic_how_to.html>`_: Covering generic functionality all LLMs should have.
 2. `Integrations <./integrations.html>`_: Covering integrations with various LLM providers.
 3. `Asynchronous <./async_llm.html>`_: Covering asynchronous functionality.
+4. `Streaming <./streaming_llm.html>`_: Covering streaming functionality.

 .. toctree::
    :maxdepth: 1

docs/modules/llms/streaming_llm.ipynb (new file, 140 lines)
@@ -0,0 +1,140 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"id": "6eaf7e66-f49c-42da-8d11-22ea13bef718",
+"metadata": {},
+"source": [
+"# Streaming with LLMs\n",
+"\n",
+"LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
+"metadata": {
+"tags": []
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"Verse 1\n",
+"I'm sippin' on sparkling water,\n",
+"It's so refreshing and light,\n",
+"It's the perfect way to quench my thirst,\n",
+"On a hot summer night.\n",
+"\n",
+"Chorus\n",
+"Sparkling water, sparkling water,\n",
+"It's the best way to stay hydrated,\n",
+"It's so refreshing and light,\n",
+"It's the perfect way to stay alive.\n",
+"\n",
+"Verse 2\n",
+"I'm sippin' on sparkling water,\n",
+"It's so bubbly and bright,\n",
+"It's the perfect way to cool me down,\n",
+"On a hot summer night.\n",
+"\n",
+"Chorus\n",
+"Sparkling water, sparkling water,\n",
+"It's the best way to stay hydrated,\n",
+"It's so refreshing and light,\n",
+"It's the perfect way to stay alive.\n",
+"\n",
+"Verse 3\n",
+"I'm sippin' on sparkling water,\n",
+"It's so crisp and clean,\n",
+"It's the perfect way to keep me going,\n",
+"On a hot summer day.\n",
+"\n",
+"Chorus\n",
+"Sparkling water, sparkling water,\n",
+"It's the best way to stay hydrated,\n",
+"It's so refreshing and light,\n",
+"It's the perfect way to stay alive."
+]
+}
+],
+"source": [
+"from langchain.llms import OpenAI\n",
+"from langchain.callbacks.base import CallbackManager\n",
+"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+"\n",
+"\n",
+"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
+"resp = llm(\"Write me a song about sparkling water.\")"
+]
+},
+{
+"cell_type": "markdown",
+"id": "61fb6de7-c6c8-48d0-a48e-1204c027a23c",
+"metadata": {
+"tags": []
+},
+"source": [
+"We still have access to the end `LLMResult` if using `generate`. However, `token_usage` is not currently supported for streaming."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "a35373f1-9ee6-4753-a343-5aee749b8527",
+"metadata": {
+"tags": []
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"Q: What did the fish say when it hit the wall?\n",
+"A: Dam!"
+]
+},
+{
+"data": {
+"text/plain": [
+"LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})"
+]
+},
+"execution_count": 8,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"llm.generate([\"Tell me a joke.\"])"
+]
+}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3 (ipykernel)",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.10.9"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 5
+}

@@ -375,6 +375,22 @@ class AgentExecutor(Chain, BaseModel):
             final_output["intermediate_steps"] = intermediate_steps
         return final_output

+    async def _areturn(
+        self, output: AgentFinish, intermediate_steps: list
+    ) -> Dict[str, Any]:
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_agent_finish(
+                output, color="green", verbose=self.verbose
+            )
+        else:
+            self.callback_manager.on_agent_finish(
+                output, color="green", verbose=self.verbose
+            )
+        final_output = output.return_values
+        if self.return_intermediate_steps:
+            final_output["intermediate_steps"] = intermediate_steps
+        return final_output
+
     def _take_next_step(
         self,
         name_to_tool_map: Dict[str, Tool],
@@ -428,6 +444,90 @@ class AgentExecutor(Chain, BaseModel):
             return AgentFinish({self.agent.return_values[0]: observation}, "")
         return output, observation

+    async def _atake_next_step(
+        self,
+        name_to_tool_map: Dict[str, Tool],
+        color_mapping: Dict[str, str],
+        inputs: Dict[str, str],
+        intermediate_steps: List[Tuple[AgentAction, str]],
+    ) -> Union[AgentFinish, Tuple[AgentAction, str]]:
+        """Take a single step in the thought-action-observation loop.
+
+        Override this to take control of how the agent makes and acts on choices.
+        """
+        # Call the LLM to see what to do.
+        output = await self.agent.aplan(intermediate_steps, **inputs)
+        # If the tool chosen is the finishing tool, then we end and return.
+        if isinstance(output, AgentFinish):
+            return output
+        # Otherwise we lookup the tool
+        if output.tool in name_to_tool_map:
+            tool = name_to_tool_map[output.tool]
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_tool_start(
+                    {"name": str(tool.func)[:60] + "..."},
+                    output,
+                    verbose=self.verbose,
+                )
+            else:
+                self.callback_manager.on_tool_start(
+                    {"name": str(tool.func)[:60] + "..."},
+                    output,
+                    verbose=self.verbose,
+                )
+            try:
+                # We then call the tool on the tool input to get an observation
+                observation = (
+                    await tool.coroutine(output.tool_input)
+                    if tool.coroutine
+                    # If the tool is not a coroutine, we run it in the executor
+                    # to avoid blocking the event loop.
+                    else await asyncio.get_event_loop().run_in_executor(
+                        None, tool.func, output.tool_input
+                    )
+                )
+                color = color_mapping[output.tool]
+                return_direct = tool.return_direct
+            except (KeyboardInterrupt, Exception) as e:
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_tool_error(e, verbose=self.verbose)
+                else:
+                    self.callback_manager.on_tool_error(e, verbose=self.verbose)
+                raise e
+        else:
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_tool_start(
+                    {"name": "N/A"}, output, verbose=self.verbose
+                )
+            else:
+                self.callback_manager.on_tool_start(
+                    {"name": "N/A"}, output, verbose=self.verbose
+                )
+            observation = f"{output.tool} is not a valid tool, try another one."
+            color = None
+            return_direct = False
+        llm_prefix = "" if return_direct else self.agent.llm_prefix
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_tool_end(
+                observation,
+                color=color,
+                observation_prefix=self.agent.observation_prefix,
+                llm_prefix=llm_prefix,
+                verbose=self.verbose,
+            )
+        else:
+            self.callback_manager.on_tool_end(
+                observation,
+                color=color,
+                observation_prefix=self.agent.observation_prefix,
+                llm_prefix=llm_prefix,
+                verbose=self.verbose,
+            )
+        if return_direct:
+            # Set the log to "" because we do not want to log it.
+            return AgentFinish({self.agent.return_values[0]: observation}, "")
+        return output, observation
+
     def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         """Run text through and get agent response."""
         # Make sure that every tool is synchronous (not a coroutine)
@@ -486,58 +586,15 @@ class AgentExecutor(Chain, BaseModel):
         iterations = 0
         # We now enter the agent loop (until it returns something).
         while self._should_continue(iterations):
-            # Call the LLM to see what to do.
-            output = await self.agent.aplan(intermediate_steps, **inputs)
-            # If the tool chosen is the finishing tool, then we end and return.
-            if isinstance(output, AgentFinish):
-                return self._return(output, intermediate_steps)
-
-            # Otherwise we lookup the tool
-            if output.tool in name_to_tool_map:
-                tool = name_to_tool_map[output.tool]
-                self.callback_manager.on_tool_start(
-                    {"name": str(tool.func)[:60] + "..."},
-                    output,
-                    verbose=self.verbose,
-                )
-                try:
-                    # We then call the tool on the tool input to get an observation
-                    observation = (
-                        await tool.coroutine(output.tool_input)
-                        if tool.coroutine
-                        # If the tool is not a coroutine, we run it in the executor
-                        # to avoid blocking the event loop.
-                        else await asyncio.get_event_loop().run_in_executor(
-                            None, tool.func, output.tool_input
-                        )
-                    )
-                    color = color_mapping[output.tool]
-                    return_direct = tool.return_direct
-                except (KeyboardInterrupt, Exception) as e:
-                    self.callback_manager.on_tool_error(e, verbose=self.verbose)
-                    raise e
-            else:
-                self.callback_manager.on_tool_start(
-                    {"name": "N/A"}, output, verbose=self.verbose
-                )
-                observation = f"{output.tool} is not a valid tool, try another one."
-                color = None
-                return_direct = False
-            llm_prefix = "" if return_direct else self.agent.llm_prefix
-            self.callback_manager.on_tool_end(
-                observation,
-                color=color,
-                observation_prefix=self.agent.observation_prefix,
-                llm_prefix=llm_prefix,
-                verbose=self.verbose,
-            )
-            intermediate_steps.append((output, observation))
-            if return_direct:
-                # Set the log to "" because we do not want to log it.
-                output = AgentFinish({self.agent.return_values[0]: observation}, "")
-                return self._return(output, intermediate_steps)
+            next_step_output = await self._atake_next_step(
+                name_to_tool_map, color_mapping, inputs, intermediate_steps
+            )
+            if isinstance(next_step_output, AgentFinish):
+                return await self._areturn(next_step_output, intermediate_steps)
+
+            intermediate_steps.append(next_step_output)
             iterations += 1
         output = self.agent.return_stopped_response(
             self.early_stopping_method, intermediate_steps, **inputs
         )
-        return self._return(output, intermediate_steps)
+        return await self._areturn(output, intermediate_steps)

@@ -1,5 +1,6 @@
 """Base callback handler that can be used to handle callbacks from langchain."""
+import asyncio
+import functools
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Union

@@ -32,63 +33,72 @@ class BaseCallbackHandler(ABC):
     @abstractmethod
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
+    ) -> Any:
         """Run when LLM starts running."""

     @abstractmethod
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
+        """Run on new LLM token. Only available when streaming is enabled."""
+
+    @abstractmethod
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
         """Run when LLM ends running."""

     @abstractmethod
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
+    ) -> Any:
         """Run when LLM errors."""

     @abstractmethod
     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
+    ) -> Any:
         """Run when chain starts running."""

     @abstractmethod
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
         """Run when chain ends running."""

     @abstractmethod
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
+    ) -> Any:
         """Run when chain errors."""

     @abstractmethod
     def on_tool_start(
         self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
-    ) -> None:
+    ) -> Any:
         """Run when tool starts running."""

     @abstractmethod
-    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         """Run when tool ends running."""

     @abstractmethod
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
+    ) -> Any:
         """Run when tool errors."""

     @abstractmethod
-    def on_text(self, text: str, **kwargs: Any) -> None:
+    def on_text(self, text: str, **kwargs: Any) -> Any:
         """Run on arbitrary text."""

     @abstractmethod
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""


 class BaseCallbackManager(BaseCallbackHandler, ABC):
     """Base callback manager that can be used to handle callbacks from LangChain."""

+    @property
+    def is_async(self) -> bool:
+        """Whether the callback manager is async."""
+        return False
+
     @abstractmethod
     def add_handler(self, callback: BaseCallbackHandler) -> None:
         """Add a handler to the callback manager."""
@@ -126,6 +136,15 @@ class CallbackManager(BaseCallbackManager):
                 if verbose or handler.always_verbose:
                     handler.on_llm_start(serialized, prompts, **kwargs)

+    def on_llm_new_token(
+        self, token: str, verbose: bool = False, **kwargs: Any
+    ) -> None:
+        """Run when LLM generates a new token."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    handler.on_llm_new_token(token, **kwargs)
+
     def on_llm_end(
         self, response: LLMResult, verbose: bool = False, **kwargs: Any
     ) -> None:
@@ -239,3 +258,287 @@ class CallbackManager(BaseCallbackManager):
     def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
         """Set handlers as the only handlers on the callback manager."""
         self.handlers = handlers
+
+
+class AsyncCallbackHandler(BaseCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running."""
+
+    async def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when LLM errors."""
+
+    async def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+
+    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Run when chain ends running."""

+    async def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when chain errors."""
+
+    async def on_tool_start(
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+
+    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Run when tool ends running."""
+
+    async def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when tool errors."""
+
+    async def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run on arbitrary text."""
+
+    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run on agent end."""
+
+
+class AsyncCallbackManager(BaseCallbackManager):
+    """Async callback manager that can be used to handle callbacks from LangChain."""
+
+    @property
+    def is_async(self) -> bool:
+        """Return whether the handler is async."""
+        return True
+
+    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
+        """Initialize callback manager."""
+        self.handlers: List[BaseCallbackHandler] = handlers
+
+    async def on_llm_start(
+        self,
+        serialized: Dict[str, Any],
+        prompts: List[str],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_llm_start):
+                        await handler.on_llm_start(serialized, prompts, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(
+                                handler.on_llm_start, serialized, prompts, **kwargs
+                            ),
+                        )
+
+    async def on_llm_new_token(
+        self, token: str, verbose: bool = False, **kwargs: Any
+    ) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_llm_new_token):
+                        await handler.on_llm_new_token(token, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(
+                                handler.on_llm_new_token, token, **kwargs
+                            ),
+                        )
+
+    async def on_llm_end(
+        self, response: LLMResult, verbose: bool = False, **kwargs: Any
+    ) -> None:
+        """Run when LLM ends running."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_llm_end):
+                        await handler.on_llm_end(response, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(handler.on_llm_end, response, **kwargs),
+                        )
+
+    async def on_llm_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> None:
+        """Run when LLM errors."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_llm_error):
+                        await handler.on_llm_error(error, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(handler.on_llm_error, error, **kwargs),
+                        )
+
+    async def on_chain_start(
+        self,
+        serialized: Dict[str, Any],
+        inputs: Dict[str, Any],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+        for handler in self.handlers:
+            if not handler.ignore_chain:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_chain_start):
+                        await handler.on_chain_start(serialized, inputs, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(
+                                handler.on_chain_start, serialized, inputs, **kwargs
+                            ),
+                        )
+
+    async def on_chain_end(
+        self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
+    ) -> None:
+        """Run when chain ends running."""
+        for handler in self.handlers:
+            if not handler.ignore_chain:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_chain_end):
+                        await handler.on_chain_end(outputs, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(handler.on_chain_end, outputs, **kwargs),
+                        )
+
+    async def on_chain_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> None:
+        """Run when chain errors."""
+        for handler in self.handlers:
+            if not handler.ignore_chain:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_chain_error):
+                        await handler.on_chain_error(error, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(handler.on_chain_error, error, **kwargs),
+                        )
+
+    async def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+        for handler in self.handlers:
+            if not handler.ignore_agent:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_tool_start):
+                        await handler.on_tool_start(serialized, action, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(
+                                handler.on_tool_start, serialized, action, **kwargs
+                            ),
+                        )
+
+    async def on_tool_end(
+        self, output: str, verbose: bool = False, **kwargs: Any
+    ) -> None:
+        """Run when tool ends running."""
+        for handler in self.handlers:
+            if not handler.ignore_agent:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_tool_end):
+                        await handler.on_tool_end(output, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(handler.on_tool_end, output, **kwargs),
+                        )
+
+    async def on_tool_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> None:
+        """Run when tool errors."""
+        for handler in self.handlers:
+            if not handler.ignore_agent:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_tool_error):
+                        await handler.on_tool_error(error, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(handler.on_tool_error, error, **kwargs),
+                        )
+
+    async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
+        """Run when text is printed."""
+        for handler in self.handlers:
+            if verbose or handler.always_verbose:
+                if asyncio.iscoroutinefunction(handler.on_text):
+                    await handler.on_text(text, **kwargs)
+                else:
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, functools.partial(handler.on_text, text, **kwargs)
+                    )
+
+    async def on_agent_finish(
+        self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
+    ) -> None:
+        """Run when agent finishes."""
+        for handler in self.handlers:
+            if not handler.ignore_agent:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_agent_finish):
+                        await handler.on_agent_finish(finish, **kwargs)
+                    else:
+                        await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(
+                                handler.on_agent_finish, finish, **kwargs
+                            ),
+                        )
+
+    def add_handler(self, handler: BaseCallbackHandler) -> None:
+        """Add a handler to the callback manager."""
+        self.handlers.append(handler)
+
+    def remove_handler(self, handler: BaseCallbackHandler) -> None:
+        """Remove a handler from the callback manager."""
+        self.handlers.remove(handler)
+
+    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
+        """Set handlers as the only handlers on the callback manager."""
+        self.handlers = handlers

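A minimal sketch of driving the `AsyncCallbackManager` added above, assuming a handler subclassed from the new `AsyncCallbackHandler`; `PrintingAsyncHandler` and the standalone `main()` driver are illustrative names, not part of this commit:

# Minimal sketch: dispatch a token through the new AsyncCallbackManager.
# The manager awaits coroutine handlers directly and falls back to
# run_in_executor for synchronous ones, as implemented above.
import asyncio

from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager


class PrintingAsyncHandler(AsyncCallbackHandler):
    """Print each streamed token as it arrives, without blocking the event loop."""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


async def main() -> None:
    manager = AsyncCallbackManager([PrintingAsyncHandler()])
    # verbose=True mirrors how the synchronous CallbackManager gates handler calls.
    await manager.on_llm_new_token("hello", verbose=True)


asyncio.run(main())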
@@ -21,8 +21,12 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         """Print out the prompts."""
         pass

+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Print out the token."""
+        pass
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Do nothing."""
+        """Collect token usage."""
         if response.llm_output is not None:
             if "token_usage" in response.llm_output:
                 token_usage = response.llm_output["token_usage"]

@@ -46,6 +46,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
         with self._lock:
             self._callback_manager.on_llm_end(response, **kwargs)

+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run when LLM generates a new token."""
+        with self._lock:
+            self._callback_manager.on_llm_new_token(token, **kwargs)
+
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:

@@ -23,6 +23,10 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         pass

+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Do nothing."""
+        pass
+
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:

langchain/callbacks/streaming_stdout.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+"""Callback Handler streams to stdout on new llm token."""
+import sys
+from typing import Any, Dict, List, Union
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class StreamingStdOutCallbackHandler(BaseCallbackHandler):
+    """Callback handler for streaming. Only works with LLMs that support streaming."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        sys.stdout.write(token)
+        sys.stdout.flush()
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running."""
+
+    def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when LLM errors."""
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Run when chain ends running."""
+
+    def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when chain errors."""
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Run when tool ends running."""
+
+    def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when tool errors."""
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run on arbitrary text."""
+
+    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run on agent end."""

@@ -18,6 +18,10 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
         for prompt in prompts:
             st.write(prompt)

+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Do nothing."""
+        pass
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
         pass

@@ -129,6 +129,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
         )
         self._start_trace(llm_run)

+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Handle a new token for an LLM run."""
+        pass
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """End a trace for an LLM run."""
         if not self._stack or not isinstance(self._stack[-1], LLMRun):

@@ -158,6 +158,13 @@ class Chain(BaseModel, ABC):

         """
         inputs = self.prep_inputs(inputs)
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_chain_start(
+                {"name": self.__class__.__name__},
+                inputs,
+                verbose=self.verbose,
+            )
+        else:
             self.callback_manager.on_chain_start(
                 {"name": self.__class__.__name__},
                 inputs,
@@ -166,8 +173,14 @@ class Chain(BaseModel, ABC):
         try:
             outputs = await self._acall(inputs)
         except (KeyboardInterrupt, Exception) as e:
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_chain_error(e, verbose=self.verbose)
+            else:
                 self.callback_manager.on_chain_error(e, verbose=self.verbose)
             raise e
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
+        else:
             self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
         return self.prep_outputs(inputs, outputs, return_only_outputs)

@@ -83,3 +83,20 @@ class ChatVectorDBChain(Chain, BaseModel):
         new_inputs["chat_history"] = chat_history_str
         answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs)
         return {self.output_key: answer}
+
+    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        question = inputs["question"]
+        chat_history_str = _get_chat_history(inputs["chat_history"])
+        if chat_history_str:
+            new_question = await self.question_generator.arun(
+                question=question, chat_history=chat_history_str
+            )
+        else:
+            new_question = question
+        # TODO: This blocks the event loop, but it's not clear how to avoid it.
+        docs = self.vectorstore.similarity_search(new_question, k=4)
+        new_inputs = inputs.copy()
+        new_inputs["question"] = new_question
+        new_inputs["chat_history"] = chat_history_str
+        answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs)
+        return {self.output_key: answer}

@@ -43,6 +43,12 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine documents into a single string."""

+    @abstractmethod
+    async def acombine_docs(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Combine documents into a single string asynchronously."""
+
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -51,6 +57,14 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
         extra_return_dict[self.output_key] = output
         return extra_return_dict

+    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        docs = inputs[self.input_key]
+        # Other keys are assumed to be needed for LLM prediction
+        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
+        output, extra_return_dict = await self.acombine_docs(docs, **other_keys)
+        extra_return_dict[self.output_key] = output
+        return extra_return_dict
+

 class AnalyzeDocumentChain(Chain, BaseModel):
     """Chain that splits documents, then analyzes it in pieces."""

@@ -140,6 +140,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
             # FYI - this is parallelized and so it is fast.
             [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
         )
+        return self._process_results(results, docs, token_max, **kwargs)
+
+    async def acombine_docs(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Combine documents in a map reduce manner.
+
+        Combine by mapping first chain over all documents, then reducing the results.
+        This reducing can be done recursively if needed (if there are many documents).
+        """
+        results = await self.llm_chain.aapply(
+            # FYI - this is parallelized and so it is fast.
+            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
+        )
+        return self._process_results(results, docs, **kwargs)
+
+    def _process_results(
+        self,
+        results: List[Dict],
+        docs: List[Document],
+        token_max: int = 3000,
+        **kwargs: Any,
+    ) -> Tuple[str, dict]:
         question_result_key = self.llm_chain.output_key
         result_docs = [
             Document(page_content=r[question_result_key], metadata=docs[i].metadata)

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

 from pydantic import BaseModel, Extra, root_validator

@@ -98,8 +98,27 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
             # FYI - this is parallelized and so it is fast.
             [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
         )
-        typed_results = cast(List[dict], results)
+        return self._process_results(docs, results)
+
+    async def acombine_docs(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Combine documents in a map rerank manner.
+
+        Combine by mapping first chain over all documents, then reranking the results.
+        """
+        results = await self.llm_chain.aapply_and_parse(
+            # FYI - this is parallelized and so it is fast.
+            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
+        )
+        return self._process_results(docs, results)
+
+    def _process_results(
+        self,
+        docs: List[Document],
+        results: Sequence[Union[str, List[str], Dict[str, str]]],
+    ) -> Tuple[str, dict]:
+        typed_results = cast(List[dict], results)
         sorted_res = sorted(
             zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
         )

@@ -84,6 +84,50 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):

     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine by mapping first chain over all, then stuffing into final chain."""
+        inputs = self._construct_initial_inputs(docs, **kwargs)
+        res = self.initial_llm_chain.predict(**inputs)
+        refine_steps = [res]
+        for doc in docs[1:]:
+            base_inputs = self._construct_refine_inputs(doc, res)
+            inputs = {**base_inputs, **kwargs}
+            res = self.refine_llm_chain.predict(**inputs)
+            refine_steps.append(res)
+        return self._construct_result(refine_steps, res)
+
+    async def acombine_docs(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Combine by mapping first chain over all, then stuffing into final chain."""
+        inputs = self._construct_initial_inputs(docs, **kwargs)
+        res = await self.initial_llm_chain.apredict(**inputs)
+        refine_steps = [res]
+        for doc in docs[1:]:
+            base_inputs = self._construct_refine_inputs(doc, res)
+            inputs = {**base_inputs, **kwargs}
+            res = await self.refine_llm_chain.apredict(**inputs)
+            refine_steps.append(res)
+        return self._construct_result(refine_steps, res)
+
+    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
+        if self.return_intermediate_steps:
+            extra_return_dict = {"intermediate_steps": refine_steps}
+        else:
+            extra_return_dict = {}
+        return res, extra_return_dict
+
+    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
+        base_info = {"page_content": doc.page_content}
+        base_info.update(doc.metadata)
+        document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
+        base_inputs = {
+            self.document_variable_name: self.document_prompt.format(**document_info),
+            self.initial_response_name: res,
+        }
+        return base_inputs
+
+    def _construct_initial_inputs(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Dict[str, Any]:
         base_info = {"page_content": docs[0].page_content}
         base_info.update(docs[0].metadata)
         document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
@@ -91,28 +135,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
             self.document_variable_name: self.document_prompt.format(**document_info)
         }
         inputs = {**base_inputs, **kwargs}
-        res = self.initial_llm_chain.predict(**inputs)
-        refine_steps = [res]
-        for doc in docs[1:]:
-            base_info = {"page_content": doc.page_content}
-            base_info.update(doc.metadata)
-            document_info = {
-                k: base_info[k] for k in self.document_prompt.input_variables
-            }
-            base_inputs = {
-                self.document_variable_name: self.document_prompt.format(
-                    **document_info
-                ),
-                self.initial_response_name: res,
-            }
-            inputs = {**base_inputs, **kwargs}
-            res = self.refine_llm_chain.predict(**inputs)
-            refine_steps.append(res)
-        if self.return_intermediate_steps:
-            extra_return_dict = {"intermediate_steps": refine_steps}
-        else:
-            extra_return_dict = {}
-        return res, extra_return_dict
+        return inputs

     @property
     def _chain_type(self) -> str:
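The refactor above splits the refine loop into helpers but keeps the same shape: an answer from the first document, then one revision per remaining document. A minimal sketch of that loop, with `call_llm` as a hypothetical placeholder rather than a langchain API:

from typing import List


def call_llm(prompt: str) -> str:
    # Placeholder for an LLM call; returns a canned string so the sketch runs offline.
    return f"<answer conditioned on: {prompt[:40]}...>"


def refine(docs: List[str], question: str) -> str:
    # Initial answer from the first chunk only.
    answer = call_llm(f"Answer {question!r} using: {docs[0]}")
    # Each later chunk refines the existing answer instead of restarting.
    for doc in docs[1:]:
        answer = call_llm(f"Existing answer: {answer}. Refine it with: {doc}")
    return answer


print(refine(["chunk 1", "chunk 2", "chunk 3"], "What changed?"))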
@@ -84,6 +84,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         # Call predict on the LLM.
         return self.llm_chain.predict(**inputs), {}

+    async def acombine_docs(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Stuff all documents into one prompt and pass to LLM."""
+        inputs = self._get_inputs(docs, **kwargs)
+        # Call predict on the LLM.
+        return await self.llm_chain.apredict(**inputs), {}
+
     @property
     def _chain_type(self) -> str:
         return "stuff_documents_chain"
@@ -61,7 +61,7 @@ class LLMChain(Chain, BaseModel):

     async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
         """Generate LLM result from inputs."""
-        prompts, stop = self.prep_prompts(input_list)
+        prompts, stop = await self.aprep_prompts(input_list)
         response = await self.llm.agenerate(prompts, stop=stop)
         return response

@@ -86,6 +86,32 @@ class LLMChain(Chain, BaseModel):
             prompts.append(prompt)
         return prompts, stop

+    async def aprep_prompts(
+        self, input_list: List[Dict[str, Any]]
+    ) -> Tuple[List[str], Optional[List[str]]]:
+        """Prepare prompts from inputs."""
+        stop = None
+        if "stop" in input_list[0]:
+            stop = input_list[0]["stop"]
+        prompts = []
+        for inputs in input_list:
+            selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
+            prompt = self.prompt.format(**selected_inputs)
+            _colored_text = get_colored_text(prompt, "green")
+            _text = "Prompt after formatting:\n" + _colored_text
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_text(
+                    _text, end="\n", verbose=self.verbose
+                )
+            else:
+                self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
+            if "stop" in inputs and inputs["stop"] != stop:
+                raise ValueError(
+                    "If `stop` is present in any inputs, should be present in all."
+                )
+            prompts.append(prompt)
+        return prompts, stop
+
     def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
         response = self.generate(input_list)
@@ -156,6 +182,11 @@ class LLMChain(Chain, BaseModel):
     ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         """Call apply and then parse the results."""
         result = self.apply(input_list)
+        return self._parse_result(result)
+
+    def _parse_result(
+        self, result: List[Dict[str, str]]
+    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
         if self.prompt.output_parser is not None:
             new_result = []
             for res in result:
@@ -165,6 +196,13 @@ class LLMChain(Chain, BaseModel):
         else:
             return result

+    async def aapply_and_parse(
+        self, input_list: List[Dict[str, Any]]
+    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
+        """Call apply and then parse the results."""
+        result = await self.aapply(input_list)
+        return self._parse_result(result)
+
     @property
     def _chain_type(self) -> str:
         return "llm_chain"
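A hedged usage sketch of the async LLMChain path added above (`agenerate` → `aprep_prompts`, plus `aapply`). It assumes an OpenAI API key is configured and that these import paths exist at this commit; the prompt and inputs are illustrative only.

import asyncio

from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate


async def main() -> None:
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
    # Each input dict is formatted by aprep_prompts and the generations are
    # requested through the LLM's async API.
    results = await chain.aapply([{"product": "socks"}, {"product": "tea"}])
    print(results)


asyncio.run(main())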
@@ -1,6 +1,7 @@
 """Load question answering chains."""
 from typing import Any, Mapping, Optional, Protocol

+from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -31,14 +32,19 @@ def _load_map_rerank_chain(
     document_variable_name: str = "context",
     rank_key: str = "score",
     answer_key: str = "answer",
+    callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> MapRerankDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+    llm_chain = LLMChain(
+        llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
+    )
     return MapRerankDocumentsChain(
         llm_chain=llm_chain,
         rank_key=rank_key,
         answer_key=answer_key,
         document_variable_name=document_variable_name,
+        verbose=verbose,
+        callback_manager=callback_manager,
         **kwargs,
     )

@@ -48,14 +54,18 @@ def _load_stuff_chain(
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "context",
     verbose: Optional[bool] = None,
+    callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+    llm_chain = LLMChain(
+        llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
+    )
     # TODO: document prompt
     return StuffDocumentsChain(
         llm_chain=llm_chain,
         document_variable_name=document_variable_name,
         verbose=verbose,
+        callback_manager=callback_manager,
         **kwargs,
     )

@@ -70,16 +80,28 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
     verbose: Optional[bool] = None,
+    callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
+    map_chain = LLMChain(
+        llm=llm,
+        prompt=question_prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+    )
     _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
+    reduce_chain = LLMChain(
+        llm=_reduce_llm,
+        prompt=combine_prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+    )
     # TODO: document prompt
     combine_document_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
+        callback_manager=callback_manager,
     )
     if collapse_prompt is None:
         collapse_chain = None
@@ -95,8 +117,11 @@ def _load_map_reduce_chain(
                 llm=_collapse_llm,
                 prompt=collapse_prompt,
                 verbose=verbose,
+                callback_manager=callback_manager,
             ),
             document_variable_name=combine_document_variable_name,
+            verbose=verbose,
+            callback_manager=callback_manager,
         )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
@@ -104,6 +129,7 @@ def _load_map_reduce_chain(
         document_variable_name=map_reduce_document_variable_name,
         collapse_document_chain=collapse_chain,
         verbose=verbose,
+        callback_manager=callback_manager,
         **kwargs,
     )

@@ -116,17 +142,29 @@ def _load_refine_chain(
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
     verbose: Optional[bool] = None,
+    callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
+    initial_chain = LLMChain(
+        llm=llm,
+        prompt=question_prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+    )
     _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
+    refine_chain = LLMChain(
+        llm=_refine_llm,
+        prompt=refine_prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+    )
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
         refine_llm_chain=refine_chain,
         document_variable_name=document_variable_name,
         initial_response_name=initial_response_name,
         verbose=verbose,
+        callback_manager=callback_manager,
         **kwargs,
     )

@@ -135,6 +173,7 @@ def load_qa_chain(
     llm: BaseLLM,
     chain_type: str = "stuff",
     verbose: Optional[bool] = None,
+    callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> BaseCombineDocumentsChain:
     """Load question answering chain.
@@ -145,6 +184,7 @@ def load_qa_chain(
             "map_reduce", and "refine".
         verbose: Whether chains should be run in verbose mode or not. Note that this
             applies to all chains that make up the final chain.
+        callback_manager: Callback manager to use for the chain.

     Returns:
         A chain to use for question answering.
@@ -160,4 +200,6 @@ def load_qa_chain(
             f"Got unsupported chain type: {chain_type}. "
             f"Should be one of {loader_mapping.keys()}"
         )
-    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
+    return loader_mapping[chain_type](
+        llm, verbose=verbose, callback_manager=callback_manager, **kwargs
+    )
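With the loader changes above, one `callback_manager` is now threaded into every sub-chain that `load_qa_chain` builds. A hedged sketch of passing one in, assuming an OpenAI API key and that `StdOutCallbackHandler` is available under `langchain.callbacks.stdout` in this tree:

from langchain.callbacks.base import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI

# The same manager reaches the map, reduce, and collapse LLMChains, so their
# callbacks fire through one place instead of per-chain defaults.
manager = CallbackManager([StdOutCallbackHandler()])
chain = load_qa_chain(
    OpenAI(temperature=0),
    chain_type="map_reduce",
    callback_manager=manager,
)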
@@ -165,14 +165,25 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_llm_start(
+                    {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+                )
+            else:
                 self.callback_manager.on_llm_start(
                     {"name": self.__class__.__name__}, prompts, verbose=self.verbose
                 )
             try:
                 output = await self._agenerate(prompts, stop=stop)
             except (KeyboardInterrupt, Exception) as e:
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_llm_error(e, verbose=self.verbose)
+                else:
                     self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_llm_end(output, verbose=self.verbose)
+            else:
                 self.callback_manager.on_llm_end(output, verbose=self.verbose)
             return output
         params = self.dict()
@@ -184,14 +195,31 @@ class BaseLLM(BaseModel, ABC):
             missing_prompts,
         ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_llm_start(
+                    {"name": self.__class__.__name__},
+                    missing_prompts,
+                    verbose=self.verbose,
+                )
+            else:
                 self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
+                    {"name": self.__class__.__name__},
+                    missing_prompts,
+                    verbose=self.verbose,
                 )
             try:
                 new_results = await self._agenerate(missing_prompts, stop=stop)
             except (KeyboardInterrupt, Exception) as e:
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_llm_error(e, verbose=self.verbose)
+                else:
                     self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_llm_end(
+                    new_results, verbose=self.verbose
+                )
+            else:
                 self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
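The `callback_manager.is_async` branches above exist because a handler hook may be either a plain function or a coroutine. The dispatch idea in isolation, with illustrative handler names rather than the actual langchain classes:

import asyncio
import inspect


class PrintHandler:
    def on_llm_end(self, output: str) -> None:
        print("sync handler saw:", output)


class AsyncPrintHandler:
    async def on_llm_end(self, output: str) -> None:
        await asyncio.sleep(0)  # e.g. flush to a websocket
        print("async handler saw:", output)


async def emit_llm_end(handlers, output: str) -> None:
    # Await coroutine hooks, call plain ones directly: the same idea as the
    # `if self.callback_manager.is_async: await ... else: ...` branches above.
    for handler in handlers:
        hook = handler.on_llm_end
        if inspect.iscoroutinefunction(hook):
            await hook(output)
        else:
            hook(output)


asyncio.run(emit_llm_end([PrintHandler(), AsyncPrintHandler()], "generation done"))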
@@ -42,6 +42,27 @@ def update_token_usage(
             token_usage[_key] += response["usage"][_key]


+def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
+    """Update response from the stream response."""
+    response["choices"][0]["text"] += stream_response["choices"][0]["text"]
+    response["choices"][0]["finish_reason"] = stream_response["choices"][0][
+        "finish_reason"
+    ]
+    response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]
+
+
+def _streaming_response_template() -> Dict[str, Any]:
+    return {
+        "choices": [
+            {
+                "text": "",
+                "finish_reason": None,
+                "logprobs": None,
+            }
+        ]
+    }
+
+
 class BaseOpenAI(BaseLLM, BaseModel):
     """Wrapper around OpenAI large language models.

@@ -88,6 +109,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
     """Adjust the probability of specific tokens being generated."""
     max_retries: int = 6
     """Maximum number of retries to make when generating."""
+    streaming: bool = False
+    """Whether to stream the results or not."""

     class Config:
         """Configuration for this pydantic object."""
@@ -129,6 +152,10 @@ class BaseOpenAI(BaseLLM, BaseModel):
                 "Could not import openai python package. "
                 "Please it install it with `pip install openai`."
             )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
         return values

     @property
@@ -215,8 +242,24 @@ class BaseOpenAI(BaseLLM, BaseModel):
         # Includes prompt, completion, and total tokens used.
         _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
         for _prompts in sub_prompts:
+            if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
+                params["stream"] = True
+                response = _streaming_response_template()
+                for stream_resp in self.completion_with_retry(
+                    prompt=_prompts, **params
+                ):
+                    self.callback_manager.on_llm_new_token(
+                        stream_resp["choices"][0]["text"], verbose=self.verbose
+                    )
+                    _update_response(response, stream_resp)
+                choices.extend(response["choices"])
+            else:
                 response = self.completion_with_retry(prompt=_prompts, **params)
                 choices.extend(response["choices"])
+            if not self.streaming:
+                # Can't update token usage if streaming
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)

@@ -232,9 +275,29 @@ class BaseOpenAI(BaseLLM, BaseModel):
         # Includes prompt, completion, and total tokens used.
         _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
         for _prompts in sub_prompts:
-            # Use OpenAI's async api https://github.com/openai/openai-python#async-api
+            if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
+                params["stream"] = True
+                response = _streaming_response_template()
+                async for stream_resp in await self.acompletion_with_retry(
+                    prompt=_prompts, **params
+                ):
+                    if self.callback_manager.is_async:
+                        await self.callback_manager.on_llm_new_token(
+                            stream_resp["choices"][0]["text"], verbose=self.verbose
+                        )
+                    else:
+                        self.callback_manager.on_llm_new_token(
+                            stream_resp["choices"][0]["text"], verbose=self.verbose
+                        )
+                    _update_response(response, stream_resp)
+                choices.extend(response["choices"])
+            else:
                 response = await self.acompletion_with_retry(prompt=_prompts, **params)
                 choices.extend(response["choices"])
+            if not self.streaming:
+                # Can't update token usage if streaming
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)

@@ -304,6 +367,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
                 for token in generator:
                     yield token
         """
+        params = self.prep_streaming_params(stop)
+        generator = self.client.create(prompt=prompt, **params)
+
+        return generator
+
+    def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """Prepare the params for streaming."""
         params = self._invocation_params
         if params["best_of"] != 1:
             raise ValueError("OpenAI only supports best_of == 1 for streaming")
@@ -312,9 +382,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         params["stream"] = True
-        generator = self.client.create(prompt=prompt, **params)
-
-        return generator
+        return params

     @property
     def _invocation_params(self) -> Dict[str, Any]:
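A hedged end-to-end sketch of the streaming support added above. It assumes an OpenAI API key is configured and that `StdOutCallbackHandler` exists in this tree; the subclass only overrides the new `on_llm_new_token` hook so each token is printed as it arrives. Note the validator above requires `n == 1` and `best_of == 1` when `streaming=True`.

from typing import Any

from langchain.callbacks.base import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.llms import OpenAI


class TokenPrinter(StdOutCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Called once per streamed token while the completion is in flight.
        print(token, end="", flush=True)


llm = OpenAI(
    streaming=True,
    temperature=0,
    callback_manager=CallbackManager([TokenPrinter()]),
    verbose=True,  # callbacks are only dispatched for verbose (or always-verbose) runs
)
llm("Tell me a joke.")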
@@ -5,9 +5,11 @@ from typing import Generator

 import pytest

+from langchain.callbacks.base import CallbackManager
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI
 from langchain.schema import LLMResult
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


 def test_openai_call() -> None:
@@ -77,9 +79,66 @@ def test_openai_streaming_error() -> None:
         llm.stream("I'm Pickle Rick")


+def test_openai_streaming_best_of_error() -> None:
+    """Test validation for streaming fails if best_of is not 1."""
+    with pytest.raises(ValueError):
+        OpenAI(best_of=2, streaming=True)
+
+
+def test_openai_streaming_n_error() -> None:
+    """Test validation for streaming fails if n is not 1."""
+    with pytest.raises(ValueError):
+        OpenAI(n=2, streaming=True)
+
+
+def test_openai_streaming_multiple_prompts_error() -> None:
+    """Test validation for streaming fails if multiple prompts are given."""
+    with pytest.raises(ValueError):
+        OpenAI(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"])
+
+
+def test_openai_streaming_call() -> None:
+    """Test valid call to openai."""
+    llm = OpenAI(max_tokens=10, streaming=True)
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
+def test_openai_streaming_callback() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    llm = OpenAI(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    llm("Write me a sentence with 100 words.")
+    assert callback_handler.llm_streams == 10
+
+
 @pytest.mark.asyncio
 async def test_openai_async_generate() -> None:
     """Test async generation."""
     llm = OpenAI(max_tokens=10)
     output = await llm.agenerate(["Hello, how are you?"])
     assert isinstance(output, LLMResult)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_streaming_callback() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    llm = OpenAI(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    result = await llm.agenerate(["Write me a sentence with 100 words."])
+    assert callback_handler.llm_streams == 10
+    assert isinstance(result, LLMResult)
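The tests above also exercise the lower-level `stream()` generator, now built on `prep_streaming_params`. A hedged sketch of consuming it directly (assumes an OpenAI API key; each item is a raw completion chunk in the OpenAI response format, so the text lives under `choices[0]["text"]`):

from langchain.llms import OpenAI

llm = OpenAI(max_tokens=10)
generator = llm.stream("Tell me a joke.")
for chunk in generator:
    # Each chunk mirrors the OpenAI streaming payload accumulated by _update_response.
    print(chunk["choices"][0]["text"], end="", flush=True)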
@@ -3,12 +3,12 @@ from typing import Any, Dict, List, Union

 from pydantic import BaseModel

-from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult


-class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
-    """Fake callback handler for testing."""
+class BaseFakeCallbackHandler(BaseModel):
+    """Base fake callback handler for testing."""

     starts: int = 0
     ends: int = 0
@@ -44,10 +44,15 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
     chain_ends: int = 0
     llm_starts: int = 0
     llm_ends: int = 0
+    llm_streams: int = 0
     tool_starts: int = 0
     tool_ends: int = 0
     agent_ends: int = 0
+
+
+class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
+    """Fake callback handler for testing."""

     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -55,6 +60,10 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
         self.llm_starts += 1
         self.starts += 1

+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run when LLM generates a new token."""
+        self.llm_streams += 1
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running."""
         self.llm_ends += 1
@@ -110,3 +119,74 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
         """Run when agent ends running."""
         self.agent_ends += 1
         self.ends += 1
+
+
+class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
+    """Fake async callback handler for testing."""
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+        self.llm_starts += 1
+        self.starts += 1
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run when LLM generates a new token."""
+        self.llm_streams += 1
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running."""
+        self.llm_ends += 1
+        self.ends += 1
+
+    async def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when LLM errors."""
+        self.errors += 1
+
+    async def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+        self.chain_starts += 1
+        self.starts += 1
+
+    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Run when chain ends running."""
+        self.chain_ends += 1
+        self.ends += 1
+
+    async def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when chain errors."""
+        self.errors += 1
+
+    async def on_tool_start(
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+        self.tool_starts += 1
+        self.starts += 1
+
+    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Run when tool ends running."""
+        self.tool_ends += 1
+        self.ends += 1
+
+    async def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when tool errors."""
+        self.errors += 1
+
+    async def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run when agent is ending."""
+        self.text += 1
+
+    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run when agent ends running."""
+        self.agent_ends += 1
+        self.ends += 1
@@ -1,13 +1,24 @@
 """Test CallbackManager."""
-from langchain.callbacks.base import BaseCallbackManager, CallbackManager
+from typing import Tuple
+
+import pytest
+
+from langchain.callbacks.base import (
+    AsyncCallbackManager,
+    BaseCallbackManager,
+    CallbackManager,
+)
 from langchain.callbacks.shared import SharedCallbackManager
 from langchain.schema import AgentAction, AgentFinish, LLMResult
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)


 def _test_callback_manager(
-    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
+    manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
     manager.on_llm_start({}, [])
@@ -20,6 +31,27 @@ def _test_callback_manager(
     manager.on_tool_end("")
     manager.on_tool_error(Exception())
     manager.on_agent_finish(AgentFinish(log="", return_values={}))
+    _check_num_calls(handlers)
+
+
+async def _test_callback_manager_async(
+    manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
+) -> None:
+    """Test the CallbackManager."""
+    await manager.on_llm_start({}, [])
+    await manager.on_llm_end(LLMResult(generations=[]))
+    await manager.on_llm_error(Exception())
+    await manager.on_chain_start({"name": "foo"}, {})
+    await manager.on_chain_end({})
+    await manager.on_chain_error(Exception())
+    await manager.on_tool_start({}, AgentAction("", "", ""))
+    await manager.on_tool_end("")
+    await manager.on_tool_error(Exception())
+    await manager.on_agent_finish(AgentFinish(log="", return_values={}))
+    _check_num_calls(handlers)
+
+
+def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
     for handler in handlers:
         if handler.always_verbose:
             assert handler.starts == 3
@@ -128,3 +160,21 @@ def test_shared_callback_manager() -> None:
     manager1.add_handler(handler1)
     manager2.add_handler(handler2)
     _test_callback_manager(manager1, handler1, handler2)
+
+
+@pytest.mark.asyncio
+async def test_async_callback_manager() -> None:
+    """Test the AsyncCallbackManager."""
+    handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
+    handler2 = FakeAsyncCallbackHandler()
+    manager = AsyncCallbackManager([handler1, handler2])
+    await _test_callback_manager_async(manager, handler1, handler2)
+
+
+@pytest.mark.asyncio
+async def test_async_callback_manager_sync_handler() -> None:
+    """Test the AsyncCallbackManager."""
+    handler1 = FakeCallbackHandler(always_verbose_=True)
+    handler2 = FakeAsyncCallbackHandler()
+    manager = AsyncCallbackManager([handler1, handler2])
+    await _test_callback_manager_async(manager, handler1, handler2)